diff --git a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/device/device_segmented_reduce.hpp b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/device/device_segmented_reduce.hpp index fc2d02fed48..749960012c3 100644 --- a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/device/device_segmented_reduce.hpp +++ b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/device/device_segmented_reduce.hpp @@ -51,23 +51,24 @@ namespace detail { template -inline hipError_t launch_segmented_arg_minmax(::rocprim::detail::target_arch arch, - InputIterator input, - OutputIterator output, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - BinaryFunction reduce_op, - ResultType initial_value, - ResultType empty_value, - dim3 grid, - dim3 block, - size_t shmem, - hipStream_t stream) +inline hipError_t launch_segmented_arg_minmax(::rocprim::detail::target current_target, + InputIterator input, + OutputIterator output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + BinaryFunction reduce_op, + ResultType initial_value, + ResultType empty_value, + dim3 grid, + dim3 block, + size_t shmem, + hipStream_t stream) { auto kernel = [=](auto arch_config) { @@ -103,7 +104,12 @@ inline hipError_t launch_segmented_arg_minmax(::rocprim::detail::target_arch arc } }; - return ::rocprim::detail::execute_launch_plan(arch, kernel, grid, block, shmem, stream); + return ::rocprim::detail::execute_launch_plan(current_target, + kernel, + grid, + block, + shmem, + stream); } /// Dispatch function similar to \p rocprim::segmented_reduce but writes \p empty_value for empty @@ -129,17 +135,24 @@ inline hipError_t segmented_arg_minmax(void* temporary_storage, using input_type = typename std::iterator_traits::value_type; using result_type = ::rocprim::accumulator_t; - using config = ::rocprim::detail::wrapped_reduce_config; + using selector = ::rocprim::detail::segmented_reduce_config_selector; ::rocprim::detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); + hipError_t result = ::rocprim::detail::host_target_arch(stream, target_arch); if(result != hipSuccess) { return result; } - const ::rocprim::detail::reduce_config_params params - = ::rocprim::detail::dispatch_target_arch(target_arch); + ::rocprim::detail::gpu target_gpu; + result = ::rocprim::detail::host_target_gpu(stream, target_gpu); + if(result != hipSuccess) + { + return result; + } + + const ::rocprim::detail::target current_target(target_arch, target_gpu); + const auto params = ::rocprim::detail::get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; if(temporary_storage == nullptr) @@ -160,18 +173,18 @@ inline hipError_t segmented_arg_minmax(void* temporary_storage, start = std::chrono::high_resolution_clock::now(); } ROCPRIM_RETURN_ON_ERROR( - launch_segmented_arg_minmax(target_arch, - input, - output, - begin_offsets, - end_offsets, - reduce_op, - static_cast(initial_value), - static_cast(empty_value), - dim3(segments), - dim3(block_size), - 0, - stream)); + launch_segmented_arg_minmax(current_target, + input, + output, + begin_offsets, + end_offsets, + reduce_op, + static_cast(initial_value), + static_cast(empty_value), + dim3(segments), + dim3(block_size), + 0, + stream)); HIPCUB_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_arg_minmax", segments, start); return hipSuccess; diff --git a/projects/rocprim/CHANGELOG.md b/projects/rocprim/CHANGELOG.md index 8232e505acd..940eaac0700 100644 --- a/projects/rocprim/CHANGELOG.md +++ b/projects/rocprim/CHANGELOG.md @@ -2,6 +2,12 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/). +## rocPRIM x.y.z for ROCm 8.0 + +### Optimizations + +* Updated config system to pick better fallback configs for untuned GPUs. + ## rocPRIM 4.2.0 for ROCm 7.2 ### Added diff --git a/projects/rocprim/benchmark/benchmark_device_batch_memcpy.cpp b/projects/rocprim/benchmark/benchmark_device_batch_memcpy.cpp index 2f11c1112de..310c4eed8c9 100644 --- a/projects/rocprim/benchmark/benchmark_device_batch_memcpy.cpp +++ b/projects/rocprim/benchmark/benchmark_device_batch_memcpy.cpp @@ -133,18 +133,23 @@ BatchMemcpyData prepare_data(hipStream_t stre BatchMemcpyData result; - using config - = rocprim::detail::wrapped_batch_memcpy_config; + using Selector = rocprim::detail::batch_memcpy_config_selector; rocprim::detail::target_arch target_arch; - hipError_t success = rocprim::detail::host_target_arch(stream, target_arch); + hipError_t success = host_target_arch(stream, target_arch); + + rocprim::detail::gpu target_gpu; + success = host_target_gpu(stream, target_gpu); + if(success != hipSuccess) { return result; } - const rocprim::detail::batch_memcpy_config_params params - = rocprim::detail::dispatch_target_arch(target_arch); + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params + = rocprim::detail::get_config(rocprim::default_config{}, get_target); const int32_t wlev_min_size = params.wlev_size_threshold; const int32_t blev_min_size = params.blev_size_threshold; diff --git a/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp b/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp index 1a88525565e..b66821fb76c 100644 --- a/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp +++ b/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp @@ -195,6 +195,21 @@ struct device_histogram_benchmark : public benchmark_utils::autotune_interface + ",cfg:" + config_name() + "}"); } + template + void clear_other_caches() + { + ( + [](auto u) + { + using U = decltype(u); + if(!std::is_same_v) + { + input_cache::instance().clear(); + } + }(Args{}), + ...); + } + void run(benchmark_utils::state&& state) override { const auto& stream = state.stream; @@ -220,6 +235,16 @@ struct device_histogram_benchmark : public benchmark_utils::autotune_interface }; }; + // Clear caches for other types that are either empty or already done. + clear_other_caches(); + const std::size_t size = bytes / Channels; size_t temporary_storage_bytes = 0; diff --git a/projects/rocprim/benchmark/benchmark_device_transform.parallel.hpp b/projects/rocprim/benchmark/benchmark_device_transform.parallel.hpp index 7a0523e0864..1aa283197d2 100644 --- a/projects/rocprim/benchmark/benchmark_device_transform.parallel.hpp +++ b/projects/rocprim/benchmark/benchmark_device_transform.parallel.hpp @@ -65,7 +65,6 @@ template struct device_transform_benchmark : public benchmark_utils::autotune_interface { - std::string name() const override { @@ -122,13 +121,15 @@ struct device_transform_benchmark : public benchmark_utils::autotune_interface { const auto launch = [&] { + using Selector = rocprim::detail::transform_config_selector; auto transform_op = [](T v) { return v + T(5); }; - return rocprim::detail::transform_impl(d_input.get(), - d_output.get(), - size, - transform_op, - stream, - debug_synchronous); + return rocprim::detail::transform_impl( + d_input.get(), + d_output.get(), + size, + transform_op, + stream, + debug_synchronous); }; state.run([&] { HIP_CHECK(launch()); }); diff --git a/projects/rocprim/rocprim/include/rocprim/config.hpp b/projects/rocprim/rocprim/include/rocprim/config.hpp index 50a65505db7..2a40cdd3544 100644 --- a/projects/rocprim/rocprim/include/rocprim/config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/config.hpp @@ -152,88 +152,100 @@ #if !defined(ROCPRIM_THREAD_STORE_USE_CACHE_MODIFIERS) #define ROCPRIM_THREAD_STORE_USE_CACHE_MODIFIERS 1 #endif - #define IS_CDNA3() \ - __builtin_amdgcn_processor_is("gfx942") || __builtin_amdgcn_processor_is("gfx950") \ - || __builtin_amdgcn_processor_is("gfx9-4-generic") - #define IS_CDNA2() __builtin_amdgcn_processor_is("gfx90a") - #define IS_CDNA1() __builtin_amdgcn_processor_is("gfx908") - #define IS_GCN5() \ - __builtin_amdgcn_processor_is("gfx900") || __builtin_amdgcn_processor_is("gfx902") \ - || __builtin_amdgcn_processor_is("gfx904") || __builtin_amdgcn_processor_is("gfx906") \ - || __builtin_amdgcn_processor_is("gfx90c") \ - || __builtin_amdgcn_processor_is("gfx9-generic") - #define IS_RDNA4() \ - __builtin_amdgcn_processor_is("gfx1200") || __builtin_amdgcn_processor_is("gfx1201") \ - || __builtin_amdgcn_processor_is("gfx12-generic") // TODO: Re-enable gfx1250 when supported by compiler - #define IS_RDNA3() \ - __builtin_amdgcn_processor_is("gfx1100") || __builtin_amdgcn_processor_is("gfx1101") \ - || __builtin_amdgcn_processor_is("gfx1102") \ - || __builtin_amdgcn_processor_is("gfx1103") \ - || __builtin_amdgcn_processor_is("gfx1152") \ - || __builtin_amdgcn_processor_is("gfx1153") \ - || __builtin_amdgcn_processor_is("gfx11-generic") - #define IS_RDNA2() \ - __builtin_amdgcn_processor_is("gfx1030") || __builtin_amdgcn_processor_is("gfx1031") \ - || __builtin_amdgcn_processor_is("gfx1032") \ - || __builtin_amdgcn_processor_is("gfx1033") \ - || __builtin_amdgcn_processor_is("gfx1034") \ - || __builtin_amdgcn_processor_is("gfx1035") \ - || __builtin_amdgcn_processor_is("gfx1036") \ - || __builtin_amdgcn_processor_is("gfx10-3-generic") - #define IS_RDNA1() \ - __builtin_amdgcn_processor_is("gfx1010") || __builtin_amdgcn_processor_is("gfx1011") \ - || __builtin_amdgcn_processor_is("gfx1012") \ - || __builtin_amdgcn_processor_is("gfx1013") \ - || __builtin_amdgcn_processor_is("gfx10-1-generic") - #define IS_GCN3() \ - __builtin_amdgcn_processor_is("gfx801") || __builtin_amdgcn_processor_is("gfx802") \ - || __builtin_amdgcn_processor_is("gfx803") || __builtin_amdgcn_processor_is("gfx805") \ - || __builtin_amdgcn_processor_is("gfx810") + #define ROCPRIM_IS_CDNA3() \ + (__builtin_amdgcn_processor_is("gfx942") || __builtin_amdgcn_processor_is("gfx950") \ + || __builtin_amdgcn_processor_is("gfx9-4-generic")) + #define ROCPRIM_IS_CDNA2() (__builtin_amdgcn_processor_is("gfx90a")) + #define ROCPRIM_IS_CDNA1() (__builtin_amdgcn_processor_is("gfx908")) + #define ROCPRIM_IS_GCN5() \ + (__builtin_amdgcn_processor_is("gfx900") || __builtin_amdgcn_processor_is("gfx902") \ + || __builtin_amdgcn_processor_is("gfx904") || __builtin_amdgcn_processor_is("gfx906") \ + || __builtin_amdgcn_processor_is("gfx90c") \ + || __builtin_amdgcn_processor_is("gfx9-generic")) + #define ROCPRIM_IS_RDNA4() \ + (__builtin_amdgcn_processor_is("gfx1200") || __builtin_amdgcn_processor_is("gfx1201") \ + || __builtin_amdgcn_processor_is( \ + "gfx12-generic")) // TODO: Re-enable gfx1250 when supported by compiler + #define ROCPRIM_IS_RDNA3() \ + (__builtin_amdgcn_processor_is("gfx1100") || __builtin_amdgcn_processor_is("gfx1101") \ + || __builtin_amdgcn_processor_is("gfx1102") || __builtin_amdgcn_processor_is("gfx1103") \ + || __builtin_amdgcn_processor_is("gfx1150") || __builtin_amdgcn_processor_is("gfx1151") \ + || __builtin_amdgcn_processor_is("gfx1152") || __builtin_amdgcn_processor_is("gfx1153") \ + || __builtin_amdgcn_processor_is("gfx11-generic")) + #define ROCPRIM_IS_RDNA2() \ + (__builtin_amdgcn_processor_is("gfx1030") || __builtin_amdgcn_processor_is("gfx1031") \ + || __builtin_amdgcn_processor_is("gfx1032") || __builtin_amdgcn_processor_is("gfx1033") \ + || __builtin_amdgcn_processor_is("gfx1034") || __builtin_amdgcn_processor_is("gfx1035") \ + || __builtin_amdgcn_processor_is("gfx1036") \ + || __builtin_amdgcn_processor_is("gfx10-3-generic")) + #define ROCPRIM_IS_RDNA1() \ + (__builtin_amdgcn_processor_is("gfx1010") || __builtin_amdgcn_processor_is("gfx1011") \ + || __builtin_amdgcn_processor_is("gfx1012") || __builtin_amdgcn_processor_is("gfx1013") \ + || __builtin_amdgcn_processor_is("gfx10-1-generic")) + #define ROCPRIM_IS_GCN3() \ + (__builtin_amdgcn_processor_is("gfx801") || __builtin_amdgcn_processor_is("gfx802") \ + || __builtin_amdgcn_processor_is("gfx803") || __builtin_amdgcn_processor_is("gfx805") \ + || __builtin_amdgcn_processor_is("gfx810")) + #define ROCPRIM_IS_GENERIC() \ + (__builtin_amdgcn_processor_is("gfx9-4-generic") \ + || __builtin_amdgcn_processor_is("gfx9-generic") \ + || __builtin_amdgcn_processor_is("gfx11-generic") \ + || __builtin_amdgcn_processor_is("gfx10-3-generic") \ + || __builtin_amdgcn_processor_is("gfx10-1-generic") \ + || __builtin_amdgcn_processor_is("gfx12-generic")) #else #if defined(ROCPRIM_TARGET_CDNA3) - #define IS_CDNA3() 1 + #define ROCPRIM_IS_CDNA3() 1 #else - #define IS_CDNA3() 0 + #define ROCPRIM_IS_CDNA3() 0 #endif #if defined(ROCPRIM_TARGET_CDNA2) - #define IS_CDNA2() 1 + #define ROCPRIM_IS_CDNA2() 1 #else - #define IS_CDNA2() 0 + #define ROCPRIM_IS_CDNA2() 0 #endif #if defined(ROCPRIM_TARGET_CDNA1) - #define IS_CDNA1() 1 + #define ROCPRIM_IS_CDNA1() 1 #else - #define IS_CDNA1() 0 + #define ROCPRIM_IS_CDNA1() 0 #endif #if defined(ROCPRIM_TARGET_GCN5) - #define IS_GCN5() 1 + #define ROCPRIM_IS_GCN5() 1 #else - #define IS_GCN5() 0 + #define ROCPRIM_IS_GCN5() 0 #endif #if defined(ROCPRIM_TARGET_RDNA4) - #define IS_RDNA4() 1 + #define ROCPRIM_IS_RDNA4() 1 #else - #define IS_RDNA4() 0 + #define ROCPRIM_IS_RDNA4() 0 #endif #if defined(ROCPRIM_TARGET_RDNA3) - #define IS_RDNA3() 1 + #define ROCPRIM_IS_RDNA3() 1 #else - #define IS_RDNA3() 0 + #define ROCPRIM_IS_RDNA3() 0 #endif #if defined(ROCPRIM_TARGET_RDNA2) - #define IS_RDNA2() 1 + #define ROCPRIM_IS_RDNA2() 1 #else - #define IS_RDNA2() 0 + #define ROCPRIM_IS_RDNA2() 0 #endif #if defined(ROCPRIM_TARGET_RDNA1) - #define IS_RDNA1() 1 + #define ROCPRIM_IS_RDNA1() 1 #else - #define IS_RDNA1() 0 + #define ROCPRIM_IS_RDNA1() 0 #endif #if defined(ROCPRIM_TARGET_GCN3) - #define IS_GCN3() 1 + #define ROCPRIM_IS_GCN3() 1 #else - #define IS_GCN3() 0 + #define ROCPRIM_IS_GCN3() 0 + #endif + + #if defined(__gfx9_generic__) || defined(__gfx9_4_generic__) || defined(__gfx10_1_generic__) \ + || defined(__gfx10_3_generic__) || defined(__gfx11_generic__) \ + || defined(__gfx12_generic__) + #define ROCPRIM_IS_GENERIC() 1 + #else + #define ROCPRIM_IS_GENERIC() 0 #endif #if !defined(ROCPRIM_THREAD_LOAD_USE_CACHE_MODIFIERS) @@ -267,7 +279,7 @@ #define ROCPRIM_DETAIL_HAS_DPP 1 #endif -#if (!defined(ROCPRIM_DISABLE_DPP) || ROCPRIM_DISABLE_DPP == 0) \ +#if(!defined(ROCPRIM_DISABLE_DPP) || ROCPRIM_DISABLE_DPP == 0) \ && (defined(ROCPRIM_DETAIL_HAS_DPP) && ROCPRIM_DETAIL_HAS_DPP == 1) #define ROCPRIM_DETAIL_USE_DPP 1 #else @@ -292,7 +304,7 @@ /// Quad size (group of 4 threads) #define ROCPRIM_QUAD_SIZE 4u -#if (defined(_MSC_VER) && !defined(__clang__)) || (defined(__GNUC__) && !defined(__clang__)) +#if(defined(_MSC_VER) && !defined(__clang__)) || (defined(__GNUC__) && !defined(__clang__)) #define ROCPRIM_UNROLL #define ROCPRIM_NO_UNROLL #else diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index e20a65b3e52..34e66f5f26b 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include @@ -159,6 +161,7 @@ using default_or_custom_config = typename std::conditional::value, Default, Config>::type; #ifndef DOXYGEN_SHOULD_SKIP_THIS +// NOTE: When adding a new target_arch also add it to gen_from_target_arch and get_target_arch_from_name enum class target_arch : unsigned int { // This must be zero, to initialize the device -> architecture cache @@ -170,9 +173,16 @@ enum class target_arch : unsigned int gfx90a = 910, gfx942 = 942, gfx950 = 950, + gfx1010 = 1010, + gfx1011 = 1011, + gfx1012 = 1012, gfx1030 = 1030, gfx1100 = 1100, + gfx1101 = 1101, gfx1102 = 1102, + gfx1103 = 1103, + gfx1150 = 1150, + gfx1151 = 1151, gfx1152 = 1152, gfx1153 = 1153, gfx1200 = 1200, @@ -181,6 +191,141 @@ enum class target_arch : unsigned int }; #endif // DOXYGEN_SHOULD_SKIP_THIS +enum class rep +{ + amdgcn, + spirv, +}; + +enum class gen +{ + unknown, + gcn3, + gcn5, + cdna1, + cdna2, + cdna3, + cdna4, + rdna1, + rdna2, + rdna3, + rdna4, +}; + +enum class gpu +{ + generic, + v620, + rx6900, + rx7900, + rx9060, + rx9070, + mi50, + mi100, + mi210, + mi300x, + mi300a, + mi308x, + mi325x, + mi350x +}; + +constexpr gen gen_from_target_arch(target_arch i) +{ + switch(i) + { + case target_arch::gfx803: return gen::gcn3; + case target_arch::gfx900: + case target_arch::gfx906: return gen::gcn5; + case target_arch::gfx908: return gen::cdna1; + case target_arch::gfx90a: return gen::cdna2; + case target_arch::gfx942: return gen::cdna3; + case target_arch::gfx950: return gen::cdna4; + case target_arch::gfx1010: + case target_arch::gfx1011: + case target_arch::gfx1012: return gen::rdna1; + case target_arch::gfx1030: return gen::rdna2; + case target_arch::gfx1100: + case target_arch::gfx1101: + case target_arch::gfx1102: + case target_arch::gfx1103: + case target_arch::gfx1150: + case target_arch::gfx1151: + case target_arch::gfx1152: + case target_arch::gfx1153: return gen::rdna3; + case target_arch::gfx1200: + case target_arch::gfx1201: return gen::rdna4; + case target_arch::unknown: + case target_arch::invalid: return gen::unknown; + } +} + +constexpr std::tuple target_gpu_names[] = { + std::make_tuple("MI350X", gpu::mi350x), + std::make_tuple("MI325X", gpu::mi325x), + std::make_tuple("MI308X", gpu::mi308x), + std::make_tuple("MI300A", gpu::mi300a), + std::make_tuple("MI300X", gpu::mi300x), + std::make_tuple("MI210", gpu::mi210), + std::make_tuple("MI100", gpu::mi100), + std::make_tuple("RX 9060", gpu::rx9060), + std::make_tuple("RX 9070", gpu::rx9070), + std::make_tuple("V620", gpu::v620), + std::make_tuple("RX 7900", gpu::rx7900), + std::make_tuple("RX 6900", gpu::rx6900), +}; + +// TODO: Remove comp_target when adopting C++20 and dropping C++17 support. +// comp_target exists, because target can not be passed as a template variable before C++20. +template +struct comp_target +{ + static constexpr gen g = g_; + static constexpr target_arch i = i_; + static constexpr gpu s = s_; + static constexpr rep r = r_; +}; + +// Macro to have a singular place for conversion, limited by C++17. +#define TARGET_TO_COMP_TARGET(CT) comp_target<(CT).g, (CT).i, (CT).s, (CT).r> + +struct target +{ + gen g; + target_arch i; + gpu s; + rep r; + + constexpr target(target_arch i, gpu s = gpu::generic, rep r = rep::amdgcn) + : g(gen_from_target_arch(i)), i(i), s(s), r(r){}; + + constexpr target(gen g = gen::unknown, + target_arch i = target_arch::unknown, + gpu s = gpu::generic, + rep r = rep::amdgcn) + : g(g), i(i), s(s), r(r){}; + + template + constexpr target(CompTarget) + : g(CompTarget::g), i(CompTarget::i), s(CompTarget::s), r(CompTarget::r) + {} + + constexpr bool operator==(target other) const + { + return g == other.g && i == other.i && s == other.s && r == other.r; + } +}; + +template +struct comp_targets +{ + template + static constexpr void for_each(F f) + { + (f(Ts{}), ...); + } +}; + /** * \brief Checks if the first `n` characters of `rhs` are equal to `lhs` * @@ -206,78 +351,54 @@ constexpr bool prefix_equals(const char* lhs, const char* rhs, std::size_t n) return i == n && *lhs == '\0'; } -struct target_arch_descriptor -{ - target_arch arch; - const char *arch_name; -}; - -#define X(ID) target_arch_descriptor{target_arch::ID, #ID} -constexpr auto target_arch_descriptors = std::array{ - X(gfx803), - X(gfx900), - X(gfx906), - X(gfx908), - X(gfx90a), - X(gfx942), - X(gfx950), - X(gfx1030), - X(gfx1100), - X(gfx1102), - X(gfx1152), - X(gfx1153), - X(gfx1200), - X(gfx1201), -}; -#undef X - -constexpr target_arch get_target_arch_from_name(const char* const arch_name, const std::size_t n) -{ - for (const auto& desc : target_arch_descriptors) - { - if(prefix_equals(desc.arch_name, arch_name, n)) - { - return desc.arch; - } +#define ROCPRIM_RETURN_IF_ARCH(ID) \ + if(prefix_equals(#ID, arch_name, n)) \ + { \ + return target_arch::ID; \ } - return target_arch::unknown; -} - -template -constexpr void for_each_arch_impl(F&& f, std::index_sequence) +constexpr target_arch get_target_arch_from_name(const char* const arch_name, const std::size_t n) { - (f(std::integral_constant{}), ...); -} + ROCPRIM_RETURN_IF_ARCH(gfx803); + ROCPRIM_RETURN_IF_ARCH(gfx900); + ROCPRIM_RETURN_IF_ARCH(gfx906); + ROCPRIM_RETURN_IF_ARCH(gfx908); + ROCPRIM_RETURN_IF_ARCH(gfx90a); + ROCPRIM_RETURN_IF_ARCH(gfx942); + ROCPRIM_RETURN_IF_ARCH(gfx950); + ROCPRIM_RETURN_IF_ARCH(gfx1010); + ROCPRIM_RETURN_IF_ARCH(gfx1011); + ROCPRIM_RETURN_IF_ARCH(gfx1012); + ROCPRIM_RETURN_IF_ARCH(gfx1030); + ROCPRIM_RETURN_IF_ARCH(gfx1100); + ROCPRIM_RETURN_IF_ARCH(gfx1101); + ROCPRIM_RETURN_IF_ARCH(gfx1102); + ROCPRIM_RETURN_IF_ARCH(gfx1103); + ROCPRIM_RETURN_IF_ARCH(gfx1150); + ROCPRIM_RETURN_IF_ARCH(gfx1151); + ROCPRIM_RETURN_IF_ARCH(gfx1152); + ROCPRIM_RETURN_IF_ARCH(gfx1153); + ROCPRIM_RETURN_IF_ARCH(gfx1200); + ROCPRIM_RETURN_IF_ARCH(gfx1201); -template -constexpr void for_each_arch(F&& f) -{ - for_each_arch_impl(std::forward(f), - std::make_index_sequence{}); + return target_arch::unknown; } +#undef ROCPRIM_RETURN_IF_ARCH -constexpr arch::wavefront::target arch_wavefront_size(const target_arch target_arch) +constexpr arch::wavefront::target gen_wavefront_size(const gen gen) { - switch(target_arch) + switch(gen) { - case target_arch::unknown: return arch::wavefront::get_target(); - case target_arch::gfx803: return arch::wavefront::target::size64; - case target_arch::gfx900: return arch::wavefront::target::size64; - case target_arch::gfx906: return arch::wavefront::target::size64; - case target_arch::gfx908: return arch::wavefront::target::size64; - case target_arch::gfx90a: return arch::wavefront::target::size64; - case target_arch::gfx942: return arch::wavefront::target::size64; - case target_arch::gfx950: return arch::wavefront::target::size64; - case target_arch::gfx1030: return arch::wavefront::target::size32; - case target_arch::gfx1100: return arch::wavefront::target::size32; - case target_arch::gfx1102: return arch::wavefront::target::size32; - case target_arch::gfx1152: return arch::wavefront::target::size32; - case target_arch::gfx1153: return arch::wavefront::target::size32; - case target_arch::gfx1200: return arch::wavefront::target::size32; - case target_arch::gfx1201: return arch::wavefront::target::size32; - - // Unreachable - case target_arch::invalid: return arch::wavefront::target::dynamic; + case gen::unknown: return arch::wavefront::get_target(); + case gen::gcn3: + case gen::gcn5: + case gen::cdna1: + case gen::cdna2: + case gen::cdna3: + case gen::cdna4: return arch::wavefront::target::size64; + case gen::rdna1: + case gen::rdna2: + case gen::rdna3: + case gen::rdna4: return arch::wavefront::target::size32; } } @@ -300,221 +421,185 @@ constexpr target_arch device_target_arch() #endif } -template -struct default_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.kernel_config.block_size; -}; - -template -struct non_blev_batch_memcpy_config_selector -{ - static constexpr unsigned int block_size = Config::template architecture_config::params - .non_blev_batch_memcpy_kernel_config.block_size; -}; - -template -struct blev_batch_memcpy_config_selector +template +struct launch_plan { - static constexpr unsigned int block_size = Config::template architecture_config::params - .blev_batch_memcpy_kernel_config.block_size; -}; + using kernel_type = void (*)(Kernel); + kernel_type kernel; + Kernel device_callback; -template -struct histogram_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.histogram_config.block_size; + void launch(dim3 grid_size, dim3 block_size, size_t shared_mem, hipStream_t stream) const + { + hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), + grid_size, + block_size, + shared_mem, + stream, + device_callback); + } }; -template -struct histogram_global_config_selector +template +constexpr target most_common_config(target target_current) { - static constexpr unsigned int block_size - = Config::template architecture_config::params.histogram_global_config.block_size; -}; + // Takes unknown as default. + target ret{}; + Targets::for_each( + [&](auto t) + { + // Skip unknown target for picking. + if(!(target{} == t)) + { + constexpr target_arch Arch = t.i; + constexpr gpu GPU = t.s; + constexpr gen Gen = t.g; + + // Update `ret` if the candidate `t` matches more specifically than the current `ret`. + // Priority order: prefer exact GPU match first; otherwise allow an Arch match (if GPU differs); + // finally allow a Gen match (if both Arch and GPU differ). This ensures we progressively + // refine the fallback from generic -> generation -> arch -> exact GPU. + if((GPU == target_current.s) + || (Arch == target_current.i + && (target_current.s != ret.s || ret.s == gpu::generic)) + || (Gen == target_current.g + && ((target_current.s != ret.s || ret.s == gpu::generic) + && (target_current.i != ret.i || ret.i == target_arch::unknown)))) + { + ret = target{t}; + } + } + }); -template -struct merge_oddeven_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.merge_oddeven_config.block_size; -}; + return ret; +} -template -struct merge_mergepath_partition_config_selector +template +constexpr typename Selector::param_type default_select_config(target t) { - static constexpr unsigned int block_size = Config::template architecture_config::params - .merge_mergepath_partition_config.block_size; -}; + using Targets = typename Selector::targets; + using Params = typename Selector::param_type; -template -struct merge_mergepath_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.merge_mergepath_config.block_size; -}; + const target target_config = most_common_config(t); -template -struct radix_sort_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.block_size; -}; + Params params{}; -template -struct radix_sort_onesweep_histogram_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.histogram.block_size; -}; + Targets::for_each( + [&](auto candidate) + { + if(target{candidate} == target_config) + { + params = Selector{candidate}.params; + } + }); -template -struct radix_sort_onesweep_sort_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.sort.block_size; -}; + return params; +} -template -struct segmented_radix_sort_warp_sort_small_config_selector +template +constexpr typename Selector::param_type get_config(Config config, target t) { - static constexpr unsigned int block_size - = Config::template architecture_config::params.warp_sort_config.block_size_small; + if constexpr(std::is_same_v) + { + return default_select_config(t); + } + else + { + return config; + } }; -template -struct segmented_radix_sort_warp_sort_meduim_config_selector +template +struct target_config { - static constexpr unsigned int block_size - = Config::template architecture_config::params.warp_sort_config.block_size_medium; + constexpr static target config_target = target{Target{}}; + constexpr static auto params = get_config(Config{}, config_target); + constexpr static auto wavefront = gen_wavefront_size(Target::g); }; -template -struct target_config +template +struct default_config_static_selector { - constexpr static auto params = Config::template architecture_config::params; - constexpr static auto wavefront = arch_wavefront_size(Arch); - constexpr static auto arch = Arch; + static constexpr auto block_size + = target_config::params.kernel_config.block_size; }; // trampoline_kernel that is fully specialized at compile-time for a single GPU architecture. -// By instantiating this template once per supported `target_arch`,the correct tuned config +// By instantiating this template once per supported `target_arch`, the correct tuned config // will be derived from the template. -template + class Target, + template class LaunchSelector> -ROCPRIM_KERNEL __launch_bounds__((LaunchSelector::block_size)) +ROCPRIM_KERNEL __launch_bounds__((LaunchSelector::block_size)) void trampoline_kernel(Kernel kernel) { - using ArchConfig = target_config; + using ArchConfig = target_config; #if !defined(ROCPRIM_TARGET_SPIRV) || ROCPRIM_TARGET_SPIRV == 0 - if constexpr(Arch == device_target_arch()) -#endif + using Targets = typename Selector::targets; + // If the arch does not exist in the Targets it should run the arch for the most_common_config. + constexpr target device_arch_target = most_common_config(target(device_target_arch())); + // If the build time arch from device_target_arch is a generic arch it is not the same as the runtime arch. + if constexpr(Target::i == device_arch_target.i) { kernel(ArchConfig{}); } -} - -template -struct launch_plan -{ - using kernel_type = void (*)(Kernel); - kernel_type kernel; - Kernel device_callback; - - void launch(dim3 grid_size, dim3 block_size, size_t shared_mem, hipStream_t stream) const + else if constexpr(ROCPRIM_IS_GENERIC()) { - hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), - grid_size, - block_size, - shared_mem, - stream, - device_callback); + kernel(ArchConfig{}); } -}; + else + { + __builtin_unreachable(); + } +#else + kernel(ArchConfig{}); +#endif +} template class LaunchSelector = default_config_selector> -auto make_launch_plan(target_arch arch, Kernel kernel) -> launch_plan + class ConfigSelector, + template class LaunchSelector = default_config_static_selector, + class Kernel> +auto make_launch_plan(target target_current, Kernel kernel) -> launch_plan { + using Targets = typename ConfigSelector::targets; + std::optional tuned_kernel = std::nullopt; - for_each_arch( - [&](auto arch_tag) - { - if(arch_tag != arch || tuned_kernel) - return; + const target target_config = most_common_config(target_current); - tuned_kernel = trampoline_kernel; + // The target config is always in Targets. + Targets::for_each( + [&](auto t) + { + if(target{t} == target_config) + { + tuned_kernel = trampoline_kernel; + } }); - if(!tuned_kernel) - { - tuned_kernel = trampoline_kernel; - } - return {tuned_kernel.value(), kernel}; } -// Host-side helper running at run-time, picking the trampoline_kernel whose template -// argument `Arch` matches the actual GPU we are executing on. template class LaunchSelector = default_config_selector> -hipError_t execute_launch_plan(target_arch arch, - Kernel kernel, - dim3 grid_size, - dim3 block_size, - size_t shmem, - hipStream_t stream) -{ - const auto launch_plan = make_launch_plan(arch, kernel); + class ConfigSelector, + template class LaunchSelector = default_config_static_selector, + class Kernel> +hipError_t execute_launch_plan( + target t, Kernel kernel, dim3 grid_size, dim3 block_size, size_t shmem, hipStream_t stream) +{ + const auto launch_plan = make_launch_plan(t, kernel); launch_plan.launch(grid_size, block_size, shmem, stream); return hipGetLastError(); } -#ifdef ROCPRIM_EXPERIMENTAL_SPIRV -template -#else -template -#endif -auto dispatch_target_arch([[maybe_unused]] const target_arch target_arch) -{ - if constexpr(!ForceUnknownArch) - { - switch(target_arch) - { - case target_arch::invalid: - assert(false && "Invalid target architecture selected at runtime."); - break; -#define X(ID) case target_arch::ID: return Config::template architecture_config::params - X(unknown); - X(gfx803); - X(gfx900); - X(gfx906); - X(gfx908); - X(gfx90a); - X(gfx942); - X(gfx950); - X(gfx1030); - X(gfx1100); - X(gfx1102); - X(gfx1152); - X(gfx1153); - X(gfx1200); - X(gfx1201); -#undef X - } - } - return Config::template architecture_config::params; -} - inline target_arch parse_gcn_arch(const char* arch_name) { static constexpr auto length = sizeof(hipDeviceProp_t::gcnArchName); @@ -602,6 +687,33 @@ inline hipError_t host_target_arch(const hipStream_t stream, target_arch& arch) return get_device_arch(device_id, arch); } +constexpr gpu get_target_gpu_from_name(std::string_view name) +{ + for(const auto& each : target_gpu_names) + { + // Look for a substring in the marketing name, e.g., + // "RX 7900" in "AMD Radeon RX 7900 XTX". + if(name.find(std::get<0>(each)) != name.npos) + { + return std::get<1>(each); + } + } + return gpu::generic; +} + +inline hipError_t host_target_gpu(const hipStream_t stream, gpu& gpu) +{ + int device_id; + ROCPRIM_RETURN_ON_ERROR(get_device_from_stream(stream, device_id)); + + hipDeviceProp_t prop; + ROCPRIM_RETURN_ON_ERROR(hipGetDeviceProperties(&prop, device_id)); + + gpu = get_target_gpu_from_name(prop.name); + + return hipSuccess; +} + } // end namespace detail /// \brief Returns a number of threads in a hardware warp for the actual device. diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp index 875c58f1454..590cadc9df0 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp @@ -40,635 +40,538 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_adjacent_difference_config - : default_adjacent_difference_config_base::type -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 1> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 1> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<1024, 2> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 1> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<512, 1> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<256, 1> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<512, 1> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 3> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<128, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 2> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<256, 1> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 3> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<128, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 2> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<64, 13> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<64, 31> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<64, 13> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<64, 29> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<256, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<128, 19> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<32, 31> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<256, 17> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<128, 3> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<128, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<1024, 3> -{}; +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {1024, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 1} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {1024, 2} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {1024, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 1} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {512, 1} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {1024, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {256, 1} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {1024, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {512, 1} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 3} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {128, 7} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 2} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {256, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {256, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {256, 1} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {64, 13} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {64, 31} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {64, 13} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {64, 29} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {256, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {128, 19} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {32, 31} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {256, 17} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {128, 3} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {128, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {1024, 3} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + return adjacent_difference_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using adjacent_difference_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp index 492a0e71bfa..81327e83498 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp @@ -40,635 +40,538 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_adjacent_difference_inplace_config - : default_adjacent_difference_config_base::type -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<1024, 13> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<32, 29> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<512, 11> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<1024, 11> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<64, 29> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<128, 23> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<128, 23> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<512, 31> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 29> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 11> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 31> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 11> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<64, 29> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<32, 23> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 19> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<128, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<32, 23> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 19> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<256, 3> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 13> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 31> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<128, 23> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 17> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 29> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<32, 7> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<32, 3> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<128, 11> -{}; +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {1024, 13} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {32, 29} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {512, 11} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {1024, 11} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {64, 29} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {128, 23} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {128, 23} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {512, 31} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 29} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 11} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {64, 31} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 11} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {64, 29} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {512, 7} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {32, 23} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 19} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {128, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {256, 3} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {512, 13} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 31} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {128, 23} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 17} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 29} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {32, 7} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {32, 3} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {128, 11} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + return adjacent_difference_inplace_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using adjacent_difference_inplace_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp index 8d601e061d3..84a59486271 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp @@ -40,702 +40,604 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_adjacent_find_config : default_adjacent_find_config_base::type -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<256, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<256, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<256, 8> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<256, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<128, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 8> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<512, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 8> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 8> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<512, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<128, 32> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 4> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 2> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<128, 4> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<64, 64> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 4> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<128, 32> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<256, 32> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<256, 64> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 2> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 32> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<256, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<256, 8> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 32> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<128, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 4> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<128, 16> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 4> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 8> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 32> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 4> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 4> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 8> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<128, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<128, 32> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 8> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<128, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<1024, 16> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 32> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<512, 32> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<1024, 16> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 32> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<512, 32> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<512, 32> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 8> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 32> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 64> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 4> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<128, 16> -{}; +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {256, 16} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {256, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {256, 8} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {256, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 8} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {512, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 8} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 8} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {512, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {128, 32} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {1024, 4} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 2} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {128, 4} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {64, 64} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 4} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {128, 32} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {256, 32} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {256, 64} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 2} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 32} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {256, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {256, 8} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 32} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {128, 4} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {1024, 4} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {128, 8} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {128, 32} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 4} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {1024, 4} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 8} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {128, 32} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 8} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {1024, 16} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {1024, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {1024, 16} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + return adjacent_find_config_picker< + comp_target, + input_type>(); +} + +// All the existing configs should be auto generated +using adjacent_find_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_copy.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_copy.hpp index e38b1621b4f..99d1f468357 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_copy.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_copy.hpp @@ -38,255 +38,350 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_batch_copy_config : default_batch_memcpy_config_base::type -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + return batch_copy_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using batch_copy_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_memcpy.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_memcpy.hpp index a3d1b21a33e..efeb5e0e689 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_memcpy.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_memcpy.hpp @@ -38,255 +38,350 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_batch_memcpy_config : default_batch_memcpy_config_base::type -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + return batch_memcpy_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using batch_memcpy_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp index f2705bfd361..bdaf163efd8 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp @@ -40,4045 +40,2756 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_binary_search_config : default_binary_search_config_base -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 16> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 16> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 8> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 2> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 8> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<128, 1> -{}; +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 2} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return binary_search_config_picker< + comp_target, + value_type, + output_type>(); +} + +// All the existing configs should be auto generated +using binary_search_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp index 929d552f288..d60b0c2f86d 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp @@ -40,397 +40,381 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_find_first_of_config : default_find_first_of_config_base::type -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 12> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 12> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<64, 15> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 9> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<128, 13> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 9> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<64, 13> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<128, 15> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 12> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<128, 16> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx1200), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 16> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 6> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<64, 8> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<128, 16> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 15> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<1024, 14> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<64, 16> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 11> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 8> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 11> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 6> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<128, 9> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 15> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 8> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 11> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<1024, 6> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<1024, 7> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<1024, 6> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<1024, 9> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<1024, 11> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 3> -{}; +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {256, 12} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {256, 12} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {64, 15} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 9} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {128, 13} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {256, 9} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {64, 13} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {128, 15} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {256, 12} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {128, 16} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 6} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {128, 16} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 3} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 15} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {1024, 14} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {64, 16} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {256, 11} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 8} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {256, 11} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 6} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {128, 9} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {256, 15} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {1024, 6} + }; + } + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {1024, 7} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {1024, 6} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {1024, 9} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {1024, 11} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + return find_first_of_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using find_first_of_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp index 7c73fde9b6e..052ff763ec0 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp @@ -40,3585 +40,2779 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_histogram_config - : default_histogram_config_base::type -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 2, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 5}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 9}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 9}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 3}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 6}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 12}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 6}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 3}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Same fallback as previous config system. + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 6}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 14}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 7}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 14}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 7}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 16}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{64, 10}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 6}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 13}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 7}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 7}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 16}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 16}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 10}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 10}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{64, 8}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 10}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 12}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 16}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 128, 2}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 5}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 128, 1}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 1}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 15}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 128, 3}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 5}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 3}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 15}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 8}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 2, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 2, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 1}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 6}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 3}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 3}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 1}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 3}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 9}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 128, 6}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 16}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 8}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 5}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 12}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 8}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 5}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + return histogram_config_picker< + comp_target, + value_type, + channels, + active_channels>(); +} + +// All the existing configs should be auto generated +using histogram_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp index 13e05329597..ed8e240ca2f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp @@ -40,4045 +40,2756 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_lower_bound_config : default_binary_search_config_base -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 16> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 2> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 2> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 2> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 4> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 2> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<128, 1> -{}; +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 2} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return lower_bound_config_picker< + comp_target, + value_type, + output_type>(); +} + +// All the existing configs should be auto generated +using lower_bound_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge.hpp index a970764bb5c..2d73e9ba88f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge.hpp @@ -40,4743 +40,3370 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_merge_config : default_merge_config_base::type -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 16> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 16> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 10> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<256, 5> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 10> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<1024, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<256, 8> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<32, 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 16> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 16> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 2> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<256, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<512, 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<256, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<256, 2> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<32, 1> -{}; +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 8} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {32, 1} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {32, 1} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 7} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 7} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 7} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 7} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 1} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + return merge_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using merge_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp index d07813ba5bd..ef6c43ecf98 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp @@ -40,4542 +40,3724 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_merge_sort_block_merge_config - : merge_sort_block_merge_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 1} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 1} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 1} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 1} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + return merge_sort_block_merge_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using merge_sort_block_merge_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp index aabb60332e8..135176444d5 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp @@ -40,4824 +40,2675 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_merge_sort_block_sort_config - : merge_sort_block_sort_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int64_t, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = short, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int8_t, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = custom_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = custom_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::half, value_type = custom_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int64_t, value_type = custom_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = custom_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = short, value_type = custom_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int8_t, value_type = custom_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 16}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + return merge_sort_block_sort_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using merge_sort_block_sort_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp index 0c0a1b540f0..141b6bf1fd2 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_flag_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 5> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 5> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 12> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 12> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 19> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<128, 28> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 4> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 12> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 7> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 12> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 12> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 12> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 5} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 5} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 19} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {128, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_flag_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_flag_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp index 80658c1166e..29bce571c57 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_predicate_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<128, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 9> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 18> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 9> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 24> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 7> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 28> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<192, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<192, 15> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 23> -{}; +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {128, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 23} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_predicate_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_predicate_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp index eac8d66329e..53ababfdf51 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_three_way_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 6> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 17> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 9> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 18> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<384, 4> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 26> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<384, 4> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 10> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 14> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 15> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 15> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 6> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 14> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 15> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 11> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 16> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 17> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 6> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 14> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 15> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 18> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 18> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 18> -{}; +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 17} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 9} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 26} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 15} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 15} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 17} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_three_way_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_three_way_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp index 82dfb30f2d3..ac23fa8ec65 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_two_way_flag_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 6> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 12> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 20> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 18> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 6> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 11> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 17> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<128, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 12> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 6> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<384, 14> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 7> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 11} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 17} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {128, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {384, 14} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_two_way_flag_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_two_way_flag_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp index a6fdba90c1c..e2132d512b9 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp @@ -40,609 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_two_way_predicate_config - : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<192, 9> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<192, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 8> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 18> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 18> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<384, 18> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<192, 15> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 9} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {192, 12} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_two_way_predicate_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_two_way_predicate_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp index 9bd1f70448d..e29d1eeba08 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp @@ -40,4747 +40,2628 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_radix_sort_block_sort_config - : radix_sort_block_sort_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 25> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 25> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 28> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 29> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 29> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 25> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 29> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 29> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 13> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 18> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<128, 16> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 25> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 28> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 28> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 28> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 29> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 31> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 30> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 30> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 25> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 23> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 31> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 10> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 30> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 30> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 31> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 28> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 30> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 27> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 26> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 24> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<128, 29> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 29> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 30> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 29> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 24> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 29> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 29> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 22> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 25> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 22> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<128, 25> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 25> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 30> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 27> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 26> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 25> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<128, 27> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<64, 29> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<64, 30> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 22> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 25> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 29> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 29> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 29> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 7> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 13> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 18> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 14> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 23> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 12> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 19> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<256, 22> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 6> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 7> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 12> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 12> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 19> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 23> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 14> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 19> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<256, 23> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 10> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 17> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<256, 22> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 7> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 11> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 15> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 30> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 29> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 31> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 28> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<1024, 18> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<1024, 22> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 7> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 15> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 14> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<512, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 29> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 31> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 18> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<1024, 16> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<1024, 22> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<1024, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<1024, 15> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<1024, 23> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 15> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 14> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<512, 18> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<256, 32> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 23> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<256, 27> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 5> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 23> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 7> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 11> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 15> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 30> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 29> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 31> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 28> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<1024, 18> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<1024, 22> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 7> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 15> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 14> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<512, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 29> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 31> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 18> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<1024, 16> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<1024, 22> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<1024, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<1024, 15> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<1024, 23> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 21> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<256, 32> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 14> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<512, 18> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 14> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<512, 21> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 5> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 5> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 5> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 5> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 21> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 19> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 16> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 29> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<512, 18> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<512, 22> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<512, 7> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<1024, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 23> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<512, 21> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<64, 16> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 17> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<64, 16> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<64, 16> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 28}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 13}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 18}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 28}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 28}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 28}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 31}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 26}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 24}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 27}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 24}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 22}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 22}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 26}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 27}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 22}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 17}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 7}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 13}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 18}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 14}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 23}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 19}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 22}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 6}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 7}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 19}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 23}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 14}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 19}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 13}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 17}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{256, 22}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 7}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{1024, 15}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 29}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 31}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 28}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{1024, 18}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 27}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{1024, 22}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 7}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 14}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 31}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 29}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 31}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{1024, 18}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{1024, 16}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{1024, 22}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 13}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{1024, 16}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{1024, 15}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{1024, 23}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 14}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{512, 18}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 32}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{256, 27}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 5}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{256, 29}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 32}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 8}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 14}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{512, 18}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 14}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 21}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 5}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 5}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 5}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 5}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 19}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{512, 18}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{512, 22}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{512, 7}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{1024, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{512, 21}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + return radix_sort_block_sort_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using radix_sort_block_sort_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp index 9e2fe8c795c..9c942f95e3f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp @@ -40,6404 +40,4637 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_radix_sort_onesweep_config - : radix_sort_onesweep_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 6, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 6, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 8}, + kernel_config_params{256, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 8}, + kernel_config_params{256, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 6}, + kernel_config_params{256, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 6, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 6, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + return radix_sort_onesweep_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using radix_sort_onesweep_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce.hpp index 6a8d99cdd52..0ebebb30cd5 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce.hpp @@ -40,702 +40,659 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_reduce_config : default_reduce_config_base::type -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 1, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<128, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<64, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<64, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 1}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {128, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {128, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {128, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {128, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {128, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {64, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {64, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {128, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + return reduce_config_picker< + comp_target, + key_type>(); +} + +// All the existing configs should be auto generated +using reduce_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce_by_key.hpp index 7c9a06431bc..6b346e8017c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce_by_key.hpp @@ -38,5561 +38,3752 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_reduce_by_key_config : default_reduce_by_key_config_base::type -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 11, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 9, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 13, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 13, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<128, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<128, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<128, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<128, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<384, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<384, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {128, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 11}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 9}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 13}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 13}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {128, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {384, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {384, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {128, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {128, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {128, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {128, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {128, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {128, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + return reduce_by_key_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using reduce_by_key_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode.hpp index 0f8c32c4cee..19f14d24b31 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode.hpp @@ -38,1075 +38,786 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_trivial_runs_config : default_reduce_by_key_config_base::type -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<256, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<384, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {384, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + return run_length_encode_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using run_length_encode_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp index e4042269f6c..731b6571442 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp @@ -40,933 +40,722 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_non_trivial_runs_config : default_non_trivial_runs_config_base::type -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 8, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<128, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<512, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<512, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 8, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<128, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<64, - 64, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 8, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<64, - 64, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<512, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<512, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 16}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 8}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {512, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {512, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 8}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {64, 16}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {64, 64}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 16}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {64, 16}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 8}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {512, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {512, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + return run_length_encode_non_trivial_config_picker< + comp_target, + key_type>(); +} + +// All the existing configs should be auto generated +using run_length_encode_non_trivial_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan.hpp index 457e83e157f..5ff7018c330 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan.hpp @@ -40,986 +40,792 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_scan_config : default_scan_config_base::type -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<64, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<256, - 2, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<64, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<64, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<64, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<128, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {64, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {64, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {256, 2}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {64, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {64, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {64, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {64, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {128, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + return scan_config_picker, + value_type>(); +} + +// All the existing configs should be auto generated +using scan_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp index 8436e810f2c..800e48e6862 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp @@ -40,5437 +40,3680 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_scan_by_key_config : default_scan_by_key_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 2, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 13, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<64, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 17, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 17, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 17, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 17, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 13, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<128, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<128, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 13, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 2}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 13}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {64, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {128, 17}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 17}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 17}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 17}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 13}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {128, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {128, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 13}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + return scan_by_key_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using scan_by_key_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_search_n.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_search_n.hpp index 30a3be07d2f..5906b95f35a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_search_n.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_search_n.hpp @@ -40,618 +40,587 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_search_n_config : default_search_n_config_base::type -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<512, 4, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 4, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<1024, 8, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<1024, 1, 8> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<512, 4, 16> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<1024, 8, 16> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<1024, 16, 16> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<256, 4, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 4, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<1024, 16, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<256, 1, 8> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<512, 4, 12> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<1024, 16, 16> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<1024, 16, 16> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<128, 2, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<256, 1, 4> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<128, 2, 4> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 2, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<1024, 2, 4> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<256, 4, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<512, 1, 12> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<1024, 2, 12> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 2, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<1024, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<512, 4, 8> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<128, 2, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<128, 4, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<128, 1, 4> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<128, 2, 8> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 2, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<128, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<128, 4, 8> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<1024, 2, 4> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 4> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<256, 4, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<512, 1, 12> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<1024, 2, 16> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 2, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<1024, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<512, 4, 8> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<64, 2, 4> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<128, 4, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx1201), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<128, 8, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<64, 1, 8> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<64, 2, 12> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<64, 4, 4> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<64, 8, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<64, 16, 4> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<256, 2, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<128, 4, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<256, 4, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<128, 2, 12> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<256, 2, 12> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<128, 4, 16> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<1024, 8, 8> -{}; +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {512, 4}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {1024, 8}, + 16 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {1024, 1}, + 8 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {512, 4}, + 16 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {1024, 8}, + 16 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {1024, 16}, + 16 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {1024, 16}, + 16 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {256, 1}, + 8 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {512, 4}, + 12 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {1024, 16}, + 16 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {1024, 16}, + 16 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {128, 2}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {256, 1}, + 4 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {128, 2}, + 4 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {1024, 2}, + 4 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {256, 4}, + 12 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {512, 1}, + 12 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {1024, 2}, + 12 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {1024, 4}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {512, 4}, + 8 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {128, 2}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {128, 4}, + 8 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {128, 1}, + 4 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {128, 2}, + 8 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {128, 4}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {128, 4}, + 8 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {64, 2}, + 4 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {128, 4}, + 12 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {128, 8}, + 12 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {64, 1}, + 8 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {64, 2}, + 12 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {64, 4}, + 4 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {64, 8}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {64, 16}, + 4 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {128, 4}, + 12 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {256, 4}, + 16 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {128, 2}, + 12 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {256, 2}, + 12 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {128, 4}, + 16 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {1024, 8}, + 8 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + return search_n_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using search_n_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp index 628fab9a2fa..43410d17367 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp @@ -40,6966 +40,4153 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<6>::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<128, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<128, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<128, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<128, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{128, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{128, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{128, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{128, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + return segmented_radix_sort_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using segmented_radix_sort_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_reduce.hpp index b04326b6ac8..9f6b159fa05 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_reduce.hpp @@ -38,650 +38,587 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_segmented_reduce_config : default_reduce_config_base::type -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {128, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + return segmented_reduce_config_picker< + comp_target, + key_type>(); +} + +// All the existing configs should be auto generated +using segmented_reduce_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp index 7ab24f63a9a..95d57ce9a55 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_flag_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 6> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 6> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 6> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 6> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 12> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 5> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 12> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<192, 8> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 12> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 12> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 7> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 12> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 28> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_flag_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using select_flag_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp index 9efc5124b5e..c43ce0896f4 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_predicate_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<384, 6> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 14> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<192, 22> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<384, 6> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 14> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 4> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 18> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 7> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<384, 18> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<192, 15> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 26> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 14> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 26> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 5> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 10> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 14> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 26> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 14} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {192, 22} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 14} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 5} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_predicate_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using select_predicate_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp index 1b8faafb244..cf346fd84d6 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp @@ -40,3511 +40,2401 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_predicated_flag_config : default_partition_config_base::type -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<256, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<256, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 25> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<512, 16> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<256, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<512, 4> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 8> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<256, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 8> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<512, 16> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 8> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<512, 8> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<192, 6> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<192, 8> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<192, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 16> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<192, 22> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 8> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<512, 8> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 6> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 8> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<192, 12> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 18> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<512, 20> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<512, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 20> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<512, 20> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 4> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 4> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<256, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<384, 4> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<384, 4> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<384, 6> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<384, 8> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<384, 18> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<256, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<384, 4> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<384, 4> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 6> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<384, 7> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<384, 8> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 12> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<384, 18> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 8> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 7> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<192, 7> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 7> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<256, 7> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 10> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<256, 11> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 18> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 18> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<256, 3> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 7> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 7> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<256, 7> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 11> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 10> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 11> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 18> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<256, 18> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 24> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<192, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<128, 7> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<128, 11> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 10> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<256, 11> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 18> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<128, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<128, 7> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 11> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 10> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 11> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 18> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<256, 18> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 24> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<128, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 5> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 5> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 6> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<192, 10> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<256, 7> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<128, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<192, 5> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<192, 5> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 6> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 10> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<192, 10> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 5> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 12> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<192, 20> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<128, 7> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<128, 11> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 10> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<256, 11> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 18> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<128, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<128, 7> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 11> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 10> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 11> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 18> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<256, 18> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 24> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<384, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 7> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 7> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<512, 7> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<384, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 15> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 14> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<512, 28> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<384, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 7> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 7> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<512, 7> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 14> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<512, 15> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 14> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<512, 28> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<384, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 28> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<512, 24> -{}; +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 25} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 6} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 12} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 22} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 6} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 12} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 20} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 20} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 8} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 11} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 11} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 5} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 20} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_predicated_flag_config_picker< + comp_target, + data_type, + flag_type>(); +} + +// All the existing configs should be auto generated +using select_predicated_flag_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail @@ -3553,4 +2443,4 @@ END_ROCPRIM_NAMESPACE /// @} // end of group primitivesmodule_deviceconfigs -#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ \ No newline at end of file +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp index 6088fe77ea1..e8adf223d76 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_unique_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<384, 18> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 14> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 22> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 16> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 20> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<384, 8> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 9> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 21> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<384, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 9> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 21> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 18> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 16> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<192, 18> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<192, 18> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 11> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 11> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 5> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 22> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 5> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 22> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<128, 32> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 11> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 11> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 32> -{}; +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 14} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 22} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 16} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 21} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 21} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {192, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {192, 18} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 11} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 11} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 22} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 22} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {128, 32} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 32} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_unique_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using select_unique_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp index 958e0c25f9a..7f0711166b8 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp @@ -40,3382 +40,2401 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_unique_by_key_config : default_partition_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<512, 6> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<384, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<384, 14> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<512, 6> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<512, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 10> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<384, 14> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 6> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<384, 14> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<384, 12> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<128, 7> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 10> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<128, 12> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<192, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<128, 7> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 12> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<128, 12> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 11> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 12> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<512, 20> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<384, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<128, 13> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<128, 22> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<384, 20> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 6> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 6> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<384, 6> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 7> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 9> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<384, 12> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<384, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 14> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<512, 14> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 6> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 6> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<384, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 7> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 9> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<384, 12> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 16> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 24> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<512, 14> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 32> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<384, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 5> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 6> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 7> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 14> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<192, 14> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 3> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 6> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 7> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 14> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 12> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<192, 14> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<192, 17> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 5> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 6> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 6> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 6> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 3> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 5> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 6> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 7> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 6> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 12> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<192, 17> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<256, 5> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 5> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 7> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 10> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 9> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<192, 5> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<192, 10> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 3> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<256, 5> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 5> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 7> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 10> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 9> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 5> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 10> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 24> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<192, 10> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<192, 28> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 5> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 6> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 6> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 6> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 3> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 5> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 6> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 7> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 6> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 12> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<192, 17> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 7> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 7> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<512, 7> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 15> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 14> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<512, 15> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 14> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 22> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<512, 24> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<384, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 7> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 7> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<512, 7> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 15> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 14> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<512, 15> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 14> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 22> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<512, 24> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 24> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 14} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 10} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 14} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 14} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 10} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {128, 11} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 20} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {128, 13} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 22} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 32} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 16} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 14} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 14} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 14} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 14} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 17} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 5} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 5} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 17} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 9} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 9} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 22} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 22} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_unique_by_key_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using select_unique_by_key_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail @@ -3424,4 +2443,4 @@ END_ROCPRIM_NAMESPACE /// @} // end of group primitivesmodule_deviceconfigs -#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_BY_KEY_HPP_ \ No newline at end of file +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_BY_KEY_HPP_ diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform.hpp index e513592ce57..12ed641df43 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform.hpp @@ -40,702 +40,604 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_transform_config : default_transform_config_base::type -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<256, 1> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<1024, 1> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<256, 1> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<128, 2> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 1> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<1024, 2> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 1> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<64, 1> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<1024, 4> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<512, 1> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<512, 2> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<512, 1> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<512, 2> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<256, 4> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<64, 2> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<512, 4> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<512, 2> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 4> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<1024, 8> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<256, 4> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<512, 4> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 1> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<512, 4> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<64, 16> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<128, 1> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<128, 2> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<128, 4> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<128, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<128, 2> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<128, 4> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<128, 8> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<256, 2> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<64, 8> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<256, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<64, 8> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<64, 16> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<128, 1> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<128, 2> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<128, 4> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<128, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<128, 2> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<128, 4> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<128, 8> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 4> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<256, 4> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<512, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<512, 4> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<256, 8> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<1024, 8> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<64, 1> -{}; +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 4} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 8} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 8} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return transform_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using transform_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp index bf2e4bc3c8c..e936d7581ea 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp @@ -40,650 +40,594 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_transform_pointer_config : default_transform_pointer_config_base::type -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<1024, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<512, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<512, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<1024, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<512, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<128, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<1024, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<1024, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<1024, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<128, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<256, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<256, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<256, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<512, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<256, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<256, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<256, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<64, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<64, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<1024, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_default + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 16}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {512, 4}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 8}, + ::rocprim::load_default + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {512, 8}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {128, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {128, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {128, 4}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {128, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {128, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {128, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {128, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {128, 8}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 8}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 8}, + ::rocprim::load_default + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 16}, + ::rocprim::load_default + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 4}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {256, 8}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {256, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 4}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {256, 8}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {256, 16}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {64, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {64, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 16}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return transform_pointer_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using transform_pointer_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp index 1a2df25c825..5c17ef4e819 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp @@ -40,4045 +40,2756 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_upper_bound_config : default_binary_search_config_base -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<128, 8> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<128, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<128, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 16> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 2> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 2> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 1> -{}; +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 2} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return upper_bound_config_picker< + comp_target, + value_type, + output_type>(); +} + +// All the existing configs should be auto generated +using upper_bound_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp index 75c0fdebf93..abc72f8528c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp @@ -1104,16 +1104,17 @@ static hipError_t batch_memcpy_func(void* temporary_storage, ROCPRIM_RETURN_ON_ERROR(std::visit( [&](auto use_atomic_block_id) { - using config = wrapped_batch_memcpy_config< - Config, - typename std::iterator_traits::value_type, - IsMemCpy>; + using Selector = batch_memcpy_config_selector; - detail::target_arch target_arch; - ROCPRIM_RETURN_ON_ERROR(detail::host_target_arch(stream, target_arch)); + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const detail::batch_memcpy_config_params params - = detail::dispatch_target_arch(target_arch); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); using BufferOffsetType = unsigned int; using BlockOffsetType = unsigned int; @@ -1241,10 +1242,9 @@ static hipError_t batch_memcpy_func(void* temporary_storage, }; auto blev_memcpy_launch_plan - = make_launch_plan(target_arch, - blev_memcpy_kernel); + = make_launch_plan( + current_target, + blev_memcpy_kernel); int blev_occupancy{}; hipError_t error @@ -1320,14 +1320,13 @@ static hipError_t batch_memcpy_func(void* temporary_storage, start_timer(); ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(target_arch, - non_blev_memcpy_kernel, - batch_memcpy_grid_size, - non_blev_block_size, - 0, - stream)); + execute_launch_plan( + current_target, + non_blev_memcpy_kernel, + batch_memcpy_grid_size, + non_blev_block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("non_blev_memcpy_kernel", num_copies, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 585e42fb6eb..2da4f8f8585 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -96,40 +96,41 @@ constexpr unsigned int merge_sort_block_size(const unsigned int item_scale) // Calculate kernel configurations, such that it will not exceed shared memory maximum template -struct merge_sort_block_sort_config_base +constexpr merge_sort_block_sort_config_params merge_sort_block_sort_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); - static constexpr bool use_fallback = merge_sort_block_size(item_scale) * 2 - * merge_sort_items_per_thread(item_scale) * item_scale - <= max_smem_per_block; + constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr bool use_fallback = merge_sort_block_size(item_scale) * 2 + * merge_sort_items_per_thread(item_scale) * item_scale + <= max_smem_per_block; // multiply by 2 to ensure block_sort's items_per_block >= block_merge's items_per_block - static constexpr unsigned int block_size - = use_fallback ? merge_sort_block_size(item_scale) * 2 : 256; - static constexpr unsigned int items_per_thread + constexpr unsigned int block_size = use_fallback ? merge_sort_block_size(item_scale) * 2 : 256; + constexpr unsigned int items_per_thread = use_fallback ? merge_sort_items_per_thread(item_scale) : 1; - using type = merge_sort_block_sort_config; + + return merge_sort_block_sort_config_params{ + {block_size, items_per_thread} + }; }; // Calculate kernel configurations, such that it will not exceed shared memory maximum // No radix_sort_block_sort_params and radix_sort_block_sort_config exist since the only // configuration member is a kernel_config. template -struct radix_sort_block_sort_config_base +constexpr kernel_config_params radix_sort_block_sort_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); // multiply by 2 to ensure block_sort's items_per_block >= block_merge's items_per_block - static constexpr unsigned int block_size = merge_sort_block_size(item_scale) * 2; - static constexpr unsigned int items_per_thread + constexpr unsigned int block_size = merge_sort_block_size(item_scale) * 2; + constexpr unsigned int items_per_thread = rocprim::min(4u, merge_sort_items_per_thread(item_scale)); - using type = kernel_config; // The items per block should be a power of two, as this is a requirement for the // radix sort merge sort. static_assert(is_power_of_two(block_size * items_per_thread), "Sorted items per block should be a power of two."); + + return kernel_config_params{block_size, items_per_thread}; }; /// \brief Default values are provided by \p merge_sort_block_merge_config_base. @@ -159,23 +160,21 @@ struct merge_sort_block_merge_config : rocprim::detail::merge_sort_block_merge_c }; template -struct merge_sort_block_merge_config_base -{ - static constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); - static constexpr bool use_fallback = merge_sort_block_size(item_scale) * 2 - * merge_sort_items_per_thread(item_scale) * item_scale - <= max_smem_per_block; - static constexpr unsigned int block_size - = use_fallback ? merge_sort_block_size(item_scale) : 128; - static constexpr unsigned int items_per_thread +constexpr merge_sort_block_merge_config_params merge_sort_block_merge_config_params_base() +{ + constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr bool use_fallback = merge_sort_block_size(item_scale) * 2 + * merge_sort_items_per_thread(item_scale) * item_scale + <= max_smem_per_block; + constexpr unsigned int block_size = use_fallback ? merge_sort_block_size(item_scale) : 128; + constexpr unsigned int items_per_thread = use_fallback ? merge_sort_items_per_thread(item_scale) : 1; - using type = merge_sort_block_merge_config; -}; + return merge_sort_block_merge_config_params{ + {block_size, 1, (1 << 17) + 70000}, + {128, 1}, + {block_size, items_per_thread} + }; +} /// \brief Default values are provided by \p radix_sort_onesweep_config_base. struct radix_sort_onesweep_config_params @@ -224,21 +223,19 @@ struct radix_sort_onesweep_config : detail::radix_sort_onesweep_config_params namespace detail { -struct reduce_config_tag -{}; - // Calculate kernel configurations, such that it will not exceed shared memory maximum template -struct radix_sort_onesweep_config_base +constexpr radix_sort_onesweep_config_params radix_sort_onesweep_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr unsigned int block_size = merge_sort_block_size(item_scale) * 4; - static constexpr unsigned int block_size = merge_sort_block_size(item_scale) * 4; - using type = radix_sort_onesweep_config< - kernel_config<256, 12>, - kernel_config, - 4>; -}; + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{block_size, ::rocprim::max(1u, 65000u / block_size / item_scale)}, + 4 + }; +} struct reduce_config_params { @@ -261,8 +258,6 @@ template struct reduce_config : rocprim::detail::reduce_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::reduce_config_tag; constexpr reduce_config() : rocprim::detail::reduce_config_params{ {BlockSize, ItemsPerThread, SizeLimit}, @@ -274,19 +269,18 @@ namespace detail { template -struct default_reduce_config_base +constexpr reduce_config_params reduce_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = reduce_config::value, - ::rocprim::max(1u, 16u / item_scale), - ::rocprim::block_reduce_algorithm::using_warp_reduce>; + return reduce_config_params{ + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; }; -struct scan_config_tag -{}; - /// \brief Provides the kernel parameters for exclusive_scan and inclusive_scan based /// on autotuned configurations or user-provided configurations. struct scan_config_params @@ -315,8 +309,6 @@ template struct scan_config : ::rocprim::detail::scan_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::scan_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS // Requirement dictated by init_lookback_scan_state_kernel. static_assert(BlockSize <= ROCPRIM_DEFAULT_MAX_BLOCK_SIZE, @@ -348,20 +340,19 @@ struct scan_config : ::rocprim::detail::scan_config_params namespace detail { -struct scan_by_key_config_tag -{}; - template -struct default_scan_config_base +constexpr scan_config_params scan_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = scan_config::value, - ::rocprim::max(1u, 16u / item_scale), - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan>; + return scan_config_params{ + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; }; /// \brief Provides the kernel parameters for exclusive_scan_by_key and inclusive_scan_by_key based @@ -392,8 +383,6 @@ template struct scan_by_key_config : ::rocprim::detail::scan_by_key_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::scan_by_key_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS // Requirement dictated by init_lookback_scan_state_kernel. static_assert(BlockSize <= ROCPRIM_DEFAULT_MAX_BLOCK_SIZE, @@ -426,66 +415,63 @@ namespace detail { template -struct default_scan_by_key_config_base +constexpr scan_by_key_config_params scan_by_key_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( - sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + constexpr unsigned int item_scale + = ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), + 2 * sizeof(int)); - using type = scan_by_key_config< - limit_block_size<256U, sizeof(Key) + sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, - ::rocprim::max(1u, 16u / item_scale), + return scan_by_key_config_params{ + {limit_block_size<256U, sizeof(Key) + sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan>; + ::rocprim::block_scan_algorithm::using_warp_scan + }; }; -struct transform_config_tag -{}; - struct transform_config_params { - kernel_config_params kernel_config{}; - cache_load_modifier load_type; + kernel_config_params kernel_config = {0, 0}; + cache_load_modifier load_type = load_default; }; } // namespace detail namespace detail { -struct segmented_radix_sort_config_tag -{}; struct warp_sort_config_params { /// \brief Allow the partitioning of batches by size for processing via size-optimized kernels. bool partitioning_allowed = false; /// \brief The number of threads in the logical warp in the small segment processing kernel. - unsigned int logical_warp_size_small = 0; + unsigned int logical_warp_size_small = 1; /// \brief The number of items processed by a thread in the small segment processing kernel. - unsigned int items_per_thread_small = 0; + unsigned int items_per_thread_small = 1; /// \brief The number of threads per block in the small segment processing kernel. - unsigned int block_size_small = 0; + unsigned int block_size_small = 1; /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into /// small and large segment groups, and each group is handled by a different, specialized kernel. unsigned int partitioning_threshold = 0; /// \brief The number of threads in the logical warp in the medium segment processing kernel. - unsigned int logical_warp_size_medium = 0; + unsigned int logical_warp_size_medium = 1; /// \brief The number of items processed by a thread in the medium segment processing kernel. - unsigned int items_per_thread_medium = 0; + unsigned int items_per_thread_medium = 1; /// \brief The number of threads per block in the medium segment processing kernel. - unsigned int block_size_medium = 0; + unsigned int block_size_medium = 1; }; struct segmented_radix_sort_config_params { - /// \brief Kernel start parameters. - kernel_config_params kernel_config{}; /// \brief Number of bits in iterations. unsigned int radix_bits = 0; - /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. - bool enable_unpartitioned_warp_sort = true; + /// \brief Kernel start parameters. + kernel_config_params kernel_config{}; /// \brief Warp sort config params warp_sort_config_params warp_sort_config{}; + /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. + bool enable_unpartitioned_warp_sort = true; }; } // namespace detail @@ -585,8 +571,6 @@ template struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::segmented_radix_sort_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS /// \brief Number of bits in iterations. @@ -608,17 +592,17 @@ struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_ constexpr segmented_radix_sort_config() : detail::segmented_radix_sort_config_params{ - SortConfig(), RadixBits, - EnableUnpartitionedWarpSort, + SortConfig(), {warp_sort_config::partitioning_allowed, - warp_sort_config::logical_warp_size_small, - warp_sort_config::items_per_thread_small, - warp_sort_config::block_size_small, - warp_sort_config::partitioning_threshold, - warp_sort_config::logical_warp_size_medium, - warp_sort_config::items_per_thread_medium, - warp_sort_config::block_size_medium} + warp_sort_config::logical_warp_size_small, + warp_sort_config::items_per_thread_small, + warp_sort_config::block_size_small, + warp_sort_config::partitioning_threshold, + warp_sort_config::logical_warp_size_medium, + warp_sort_config::items_per_thread_medium, + warp_sort_config::block_size_medium}, + EnableUnpartitionedWarpSort } {} #endif @@ -627,18 +611,17 @@ struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_ namespace detail { /// \brief Default segmented_radix_sort kernel configurations, such that the maximum shared memory is not exceeded. -/// -/// \tparam RadixBits Bits used during the sorting. -/// \tparam ItemsPerThread Items per thread when type Key has size 1. -template -struct default_segmented_radix_sort_config_base -{ - static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( - sizeof(unsigned int) + sizeof(unsigned int), sizeof(int)); - using type = segmented_radix_sort_config, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - true>; +template +constexpr segmented_radix_sort_config_params segmented_radix_sort_config_params_base() +{ + constexpr unsigned int radix_bits = 6; + + return segmented_radix_sort_config_params{ + radix_bits, + kernel_config_params{128, 17u}, + warp_sort_config_params{true, 32, 4, 256, 3000, 32, 4, 256}, + true + }; }; } // namespace detail @@ -654,8 +637,6 @@ template struct transform_config : public detail::transform_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::transform_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS /// \brief Number of threads in a block. @@ -689,8 +670,6 @@ template struct transform_pointer_config : public detail::transform_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::transform_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS /// \brief Number of threads in a block. @@ -718,29 +697,26 @@ namespace detail { template -struct default_transform_pointer_config_base +constexpr transform_config_params transform_pointer_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(uint128_t), sizeof(Value)); - using type = transform_config<256, item_scale>; -}; + return transform_config_params{ + {256, item_scale} + }; +} template -struct default_transform_config_base +constexpr transform_config_params transform_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = transform_config<256, ::rocprim::max(1u, 16u / item_scale)>; -}; - -struct binary_search_config_tag : public transform_config_tag -{}; -struct upper_bound_config_tag : public transform_config_tag -{}; -struct lower_bound_config_tag : public transform_config_tag -{}; + return transform_config_params{ + {256, ::rocprim::max(1u, 16u / item_scale)} + }; +} } // namespace detail @@ -752,10 +728,7 @@ template struct binary_search_config : transform_config -{ - /// \brief Identifies the algorithm associated to the config. - using tag = detail::binary_search_config_tag; -}; +{}; /// \brief Configuration for the device-level upper bound operation. /// \tparam BlockSize Number of threads in a block. @@ -765,10 +738,7 @@ template struct upper_bound_config : transform_config -{ - /// \brief Identifies the algorithm associated to the config. - using tag = detail::upper_bound_config_tag; -}; +{}; /// \brief Configuration for the device-level lower bound operation. /// \tparam BlockSize Number of threads in a block. @@ -778,23 +748,18 @@ template struct lower_bound_config : transform_config -{ - /// \brief Identifies the algorithm associated to the config. - using tag = detail::lower_bound_config_tag; -}; +{}; namespace detail { -struct histogram_config_tag -{}; - template -struct default_binary_search_config_base - : binary_search_config< - limit_block_size<256U, sizeof(Value) + sizeof(Output), ROCPRIM_WARP_SIZE_64>::value, - 1> -{}; +constexpr transform_config_params binary_search_config_params_base() +{ + return transform_config_params{ + {limit_block_size<256U, sizeof(Value) + sizeof(Output), ROCPRIM_WARP_SIZE_64>::value, 1} + }; +} /// \brief Provides the kernel parameters for histogram_even, multi_histogram_even, /// histogram_range, and multi_histogram_range based on autotuned configurations or @@ -807,7 +772,7 @@ struct histogram_config_params unsigned int shared_impl_max_bins = 0; unsigned int shared_impl_histograms = 0; - kernel_config_params histogram_global_config = {0, 0}; + kernel_config_params histogram_global_config = histogram_config; }; } // namespace detail @@ -830,8 +795,6 @@ template struct histogram_config : detail::histogram_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::histogram_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS using histogram = HistogramConfig; @@ -852,23 +815,21 @@ namespace detail { template -struct default_histogram_config_base +constexpr histogram_config_params histogram_config_params_base() { - static constexpr unsigned int item_scale - = ::rocprim::detail::ceiling_div(sizeof(Sample), sizeof(int)); + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Sample), sizeof(int)); - using type - = histogram_config>; -}; + constexpr kernel_config_params kernel_params + = {256, ::rocprim::max(8u / Channels / item_scale, 1u)}; -struct adjacent_difference_config_tag -{}; + return histogram_config_params{kernel_params, 1024, 2048, 3, kernel_params}; +}; struct adjacent_difference_config_params { - kernel_config_params kernel_config; - ::rocprim::block_load_method block_load_method; - ::rocprim::block_store_method block_store_method; + kernel_config_params kernel_config{}; + ::rocprim::block_load_method block_load_method = block_load_method::block_load_transpose; + ::rocprim::block_store_method block_store_method = block_store_method::block_store_transpose; }; } // namespace detail @@ -886,8 +847,6 @@ template struct adjacent_difference_config : public detail::adjacent_difference_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::adjacent_difference_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS static constexpr ::rocprim::block_load_method block_load_method = BlockLoadMethod; static constexpr ::rocprim::block_store_method block_store_method = BlockStoreMethod; @@ -907,16 +866,17 @@ namespace detail { template -struct default_adjacent_difference_config_base +constexpr adjacent_difference_config_params adjacent_difference_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = adjacent_difference_config< - limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, - ::rocprim::max(1u, 16u / item_scale), + return adjacent_difference_config_params{ + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose>; + ::rocprim::block_store_method::block_store_transpose + }; }; } // namespace detail @@ -924,25 +884,22 @@ struct default_adjacent_difference_config_base namespace detail { -struct batch_memcpy_config_tag -{}; - struct batch_memcpy_config_params { /// \brief Kernel config for thread- and warp-level copy - kernel_config_params non_blev_batch_memcpy_kernel_config; + kernel_config_params non_blev_batch_memcpy_kernel_config{}; /// \brief Number of bytes per thread for thread-level copy - unsigned int tlev_items_per_thread; + unsigned int tlev_items_per_thread = 0; /// \brief Kernel config for block-level copy - kernel_config_params blev_batch_memcpy_kernel_config; + kernel_config_params blev_batch_memcpy_kernel_config{}; /// \brief Minimum size to use warp-level copy instead of thread-level - unsigned int wlev_size_threshold; + unsigned int wlev_size_threshold = 0; /// \brief Minimum size to use block-level copy instead of warp-level - unsigned int blev_size_threshold; + unsigned int blev_size_threshold = 0; }; } // namespace detail @@ -963,8 +920,6 @@ template struct batch_memcpy_config : public detail::batch_memcpy_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::batch_memcpy_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS static constexpr unsigned int non_blev_block_size = NonBlevBlockSize; static constexpr unsigned int non_blev_items_per_thread = NonBlevItemsPerThread; @@ -989,19 +944,20 @@ namespace detail { template -struct default_batch_memcpy_config_base +constexpr batch_memcpy_config_params batch_memcpy_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type - = batch_memcpy_config::value, - ::rocprim::max(1u, 16u / item_scale), - ::rocprim::max(1u, 16u / item_scale), - limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, - ::rocprim::max(1u, 16u / item_scale), - 128, - 1024>; + return batch_memcpy_config_params{ + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, + ::rocprim::max(1u, 16u / item_scale), + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, + 128, + 1024 + }; }; } // namespace detail @@ -1040,11 +996,11 @@ namespace detail struct partition_config_params { - kernel_config_params kernel_config; - block_load_method key_block_load_method; - block_load_method value_block_load_method; - block_load_method flag_block_load_method; - block_scan_algorithm block_scan_method; + kernel_config_params kernel_config = {0, 0}; + block_load_method key_block_load_method = ::rocprim::block_load_method::block_load_transpose; + block_load_method value_block_load_method = ::rocprim::block_load_method::block_load_transpose; + block_load_method flag_block_load_method = ::rocprim::block_load_method::block_load_transpose; + block_scan_algorithm block_scan_method = ::rocprim::block_scan_algorithm::using_warp_scan; }; } // namespace detail @@ -1101,35 +1057,33 @@ struct select_config : public detail::partition_config_params namespace detail { -template -struct default_partition_config_base +template +constexpr partition_config_params partition_config_params_base() { - static constexpr unsigned int item_scale + constexpr int ItemScaleBase = 13; + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); using offset_t = std::conditional_t; // Additional shared memory is required by the lookback scan state. - static constexpr unsigned int shared_mem_offset = sizeof( + constexpr unsigned int shared_mem_offset = sizeof( typename offset_lookback_scan_prefix_op>::storage_type); - using type = select_config< - fallback_block_size<256U, sizeof(Key), ROCPRIM_WARP_SIZE_64, shared_mem_offset>::value, - ::rocprim::max(1u, ItemScaleBase / item_scale), - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan>; + return partition_config_params{ + {fallback_block_size<256U, sizeof(Key), ROCPRIM_WARP_SIZE_64, shared_mem_offset>::value, + ::rocprim::max(1u, ItemScaleBase / item_scale)} + }; }; struct reduce_by_key_config_params { kernel_config_params kernel_config; - unsigned int tiles_per_block; block_load_method load_keys_method; block_load_method load_values_method; block_scan_algorithm scan_algorithm; + unsigned int tiles_per_block = 1; }; } // namespace detail @@ -1178,10 +1132,10 @@ struct reduce_by_key_config : public detail::reduce_by_key_config_params constexpr reduce_by_key_config() : detail::reduce_by_key_config_params{ {BlockSize, ItemsPerThread, SizeLimit}, - TilesPerBlock, LoadKeysMethod, LoadValuesMethod, - ScanAlgorithm + ScanAlgorithm, + TilesPerBlock } {}; #endif }; @@ -1190,32 +1144,36 @@ namespace detail { template -struct default_reduce_by_key_config_base +constexpr reduce_by_key_config_params reduce_by_key_config_params_base() { - using small_config = reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - sizeof(Value) < 16 ? 1 : 2>; + constexpr auto small_config = reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + sizeof(Value) < 16 ? 1 : 2, + }; + + if constexpr(std::max(sizeof(Key), sizeof(Value)) <= 16) + { + return small_config; + } - static constexpr unsigned int size_memory_per_item = std::max(sizeof(Key), sizeof(Value)); - static constexpr unsigned int item_scale + constexpr unsigned int size_memory_per_item = std::max(sizeof(Key), sizeof(Value)); + constexpr unsigned int item_scale = static_cast(ceiling_div(size_memory_per_item, 2 * sizeof(int))); - static constexpr unsigned int items_per_thread = std::max(1u, 15u / item_scale); - - using large_config - = reduce_by_key_config::value, - items_per_thread, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2>; - - using type = std:: - conditional_t; + constexpr unsigned int items_per_thread = std::max(1u, 15u / item_scale); + + constexpr auto large_config = reduce_by_key_config_params{ + {limit_block_size<256U, items_per_thread * size_memory_per_item, ROCPRIM_WARP_SIZE_64>:: + value, items_per_thread}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + + return large_config; }; } // namespace detail @@ -1261,8 +1219,6 @@ struct nth_element_config : public detail::nth_element_config_params namespace detail { -struct non_trivial_runs_config_tag -{}; struct non_trivial_runs_config_params { @@ -1287,8 +1243,6 @@ template struct non_trivial_runs_config : public detail::non_trivial_runs_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::non_trivial_runs_config_tag; #ifndef DOXYGEN_DOCUMENTATION_BUILD /// \brief Number of threads in a block. static constexpr unsigned int block_size = BlockSize; @@ -1311,35 +1265,41 @@ namespace detail { template -struct default_non_trivial_runs_config_base +constexpr non_trivial_runs_config_params non_trivial_runs_config_params_base() { - static constexpr unsigned int items_per_thread = 16; - using small_config = non_trivial_runs_config<256U, - items_per_thread, - block_load_method::block_load_vectorize, - block_scan_algorithm::reduce_then_scan>; + constexpr unsigned int items_per_thread = 16; + constexpr auto small_config = non_trivial_runs_config_params{ + {256U, items_per_thread}, + block_load_method::block_load_vectorize, + block_scan_algorithm::reduce_then_scan + }; + + if constexpr(sizeof(InputT) < 8) + { + return small_config; + } using OffsetCountPairT = ::rocprim::tuple; - static constexpr unsigned int size_memory_per_item + constexpr unsigned int size_memory_per_item = std::max(sizeof(InputT), sizeof(OffsetCountPairT)); // Additional shared memory is required by the lookback scan state. - static constexpr unsigned int shared_mem_offset + constexpr unsigned int shared_mem_offset = sizeof(typename offset_lookback_scan_prefix_op< OffsetCountPairT, lookback_scan_state>::storage_type); - using big_config - = non_trivial_runs_config::value, - items_per_thread, - block_load_method::block_load_warp_transpose, - block_scan_algorithm::reduce_then_scan>; + constexpr auto big_config = non_trivial_runs_config_params{ + {detail::limit_block_size<64U, + items_per_thread * size_memory_per_item, + ROCPRIM_WARP_SIZE_64, shared_mem_offset>::value, + items_per_thread}, + block_load_method::block_load_warp_transpose, + block_scan_algorithm::reduce_then_scan + }; - using type = std::conditional_t; + return big_config; }; struct find_first_of_config_params @@ -1347,12 +1307,9 @@ struct find_first_of_config_params kernel_config_params kernel_config{}; }; -struct adjacent_find_config_tag -{}; - struct adjacent_find_config_params { - kernel_config_params kernel_config; + kernel_config_params kernel_config{}; }; } // namespace detail @@ -1380,8 +1337,6 @@ struct find_first_of_config : public detail::find_first_of_config_params template struct adjacent_find_config : public detail::adjacent_find_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::adjacent_find_config_tag; #ifndef DOXYGEN_DOCUMENTATION_BUILD constexpr adjacent_find_config() : detail::adjacent_find_config_params{ @@ -1395,23 +1350,26 @@ namespace detail { template -struct default_find_first_of_config_base +constexpr find_first_of_config_params find_first_of_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = find_first_of_config<256, ::rocprim::max(1u, 16u / item_scale)>; + return find_first_of_config_params{ + {256, ::rocprim::max(1u, 16u / item_scale)} + }; }; template -struct default_adjacent_find_config_base +constexpr adjacent_find_config_params adjacent_find_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(InputT), sizeof(int)); - using type - = adjacent_find_config::value, - ::rocprim::max(1u, 16u / item_scale)>; + return adjacent_find_config_params{ + {limit_block_size<1024U, sizeof(InputT), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)} + }; }; } // namespace detail @@ -1473,15 +1431,16 @@ struct search_n_config : public detail::search_n_config_params namespace detail { template -struct default_search_n_config_base +constexpr search_n_config_params search_n_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(InputType), sizeof(int)); - using type - = search_n_config::value, - ::rocprim::max(1u, 10u / item_scale), - 8>; + return search_n_config_params{ + {limit_block_size<256u, sizeof(InputType), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale)}, + 8 + }; }; } // namespace detail @@ -1523,30 +1482,45 @@ namespace detail { template -struct default_merge_config_base +constexpr merge_config_params merge_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( - ::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>; -}; - -template -struct default_merge_config_base -{ - static constexpr unsigned int item_scale - = ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); - - using type = select_type< - select_type_case>, - select_type_case>, - select_type_case>, - merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>>; -}; + if constexpr(std::is_same_v) + { + constexpr unsigned int item_scale + = ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); + + if constexpr(sizeof(Key) <= 2) + return merge_config_params{ + {256, 11} + }; + else if constexpr(sizeof(Key) <= 4) + return merge_config_params{ + {256, 10} + }; + else if constexpr(sizeof(Key) <= 8) + return merge_config_params{ + {256, 7} + }; + else + return merge_config_params{ + {fallback_block_size<256u, sizeof(Key), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale)} + }; + } + else + { + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( + ::rocprim::max(sizeof(Key), sizeof(Value)), + sizeof(int)); + + return merge_config_params{ + {fallback_block_size<256u, + ::rocprim::max(sizeof(Key), sizeof(Value)), + ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale)} + }; + } +} } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_nth_element.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_nth_element.hpp index d5a4bf8f3f3..0b6173127e7 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_nth_element.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_nth_element.hpp @@ -542,13 +542,14 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void } template ROCPRIM_INLINE hipError_t - nth_element_keys_impl(detail::target_arch target_arch, + nth_element_keys_impl(target current_target, KeysIterator keys, typename std::iterator_traits::value_type* keys_buffer, typename std::iterator_traits::value_type* tree, @@ -620,12 +621,12 @@ hipError_t size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - find_splitters_kernel, - 1, - num_splitters, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + find_splitters_kernel, + 1, + num_splitters, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("find_splitters_kernel", size, start); start_timer(); @@ -638,12 +639,12 @@ hipError_t equality_buckets, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - count_bucket_sizes_kernel, - num_blocks, - num_threads_per_block, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + count_bucket_sizes_kernel, + num_blocks, + num_threads_per_block, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("count_bucket_sizes_kernel", size, start); start_timer(); @@ -654,12 +655,13 @@ hipError_t equality_buckets, rank); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - find_nth_element_bucket_kernel, - 1, - num_buckets, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(current_target, + find_nth_element_bucket_kernel, + 1, + num_buckets, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("find_nth_element_bucket_kernel", size, start); start_timer(); @@ -675,12 +677,12 @@ hipError_t compare_function, ordered_bid); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - copy_buckets_kernel, - num_blocks, - num_threads_per_block, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + copy_buckets_kernel, + num_blocks, + num_threads_per_block, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("copy_buckets_kernel", size, start); // Copy the results in keys_buffer back to the keys @@ -720,12 +722,12 @@ hipError_t auto block_sort_kernel = [=](auto arch_config) { block_sort_kernel_impl(keys, size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - block_sort_kernel, - 1, - stop_recursion_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + block_sort_kernel, + 1, + stop_recursion_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("kernel_block_sort", size, start); return hipSuccess; } diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_radix_sort.hpp index 36c570b92fe..0bf41a523fa 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_radix_sort.hpp @@ -1276,7 +1276,7 @@ struct onesweep_iteration_helper for(unsigned int i = 0; i < ItemsPerThread; ++i) { // It only seems worse on gfx942 in some cases. - if ROCPRIM_AMDGCN_CONSTEXPR(IS_CDNA3()) + if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_CDNA3()) { const int offset = ranks[i] - x; if(offset >= 0 && offset < static_cast(BlockSize * NKey)) @@ -1369,7 +1369,7 @@ struct onesweep_iteration_helper for(unsigned int i = 0; i < ItemsPerThread; ++i) { // It only seems worse on gfx942 in some cases. - if ROCPRIM_AMDGCN_CONSTEXPR(IS_CDNA3()) + if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_CDNA3()) { const int offset = ranks[i] - x; if(offset >= 0 && offset < static_cast(BlockSize * NValue)) diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search.hpp index 50c4d435577..89d8ad6478a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search.hpp @@ -290,13 +290,15 @@ hipError_t search_impl(void* temporary_storage, using key_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config = wrapped_search_config; + using selector = search_config_selector; target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target current_target(target_arch, target_gpu); - const search_config_params params = dispatch_target_arch(target_arch); - + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -349,12 +351,12 @@ hipError_t search_impl(void* temporary_storage, keys_size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_shared_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_shared_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel_shared", size, start); } else @@ -370,12 +372,12 @@ hipError_t search_impl(void* temporary_storage, keys_size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_shared_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_shared_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel_shared", size, start); } } @@ -393,12 +395,12 @@ hipError_t search_impl(void* temporary_storage, keys_size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel", size, start); } else @@ -414,12 +416,12 @@ hipError_t search_impl(void* temporary_storage, keys_size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel", size, start); } } diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp index f41c546c9ad..4e7c168ec01 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp @@ -65,7 +65,8 @@ hipError_t search_n_impl(void* temporary_storage, { using input_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config = wrapped_search_n_config; + + using Selector = search_n_config_selector; // The `size` must greater than or equal to `count` if(count > size) @@ -75,8 +76,12 @@ hipError_t search_n_impl(void* temporary_storage, target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); - const auto params = dispatch_target_arch(target_arch); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -164,12 +169,12 @@ hipError_t search_n_impl(void* temporary_storage, } } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_n_normal_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_n_normal_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_normal_kernel", size, start); ROCPRIM_RETURN_ON_ERROR( transform(tmp_output, output, 1, identity(), stream, debug_synchronous)); @@ -244,12 +249,12 @@ hipError_t search_n_impl(void* temporary_storage, } } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_n_find_heads_kernel, - num_blocks_for_find_heads, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_n_find_heads_kernel, + num_blocks_for_find_heads, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_find_heads_kernel", possible_head_exist_size, start); @@ -294,12 +299,12 @@ hipError_t search_n_impl(void* temporary_storage, filtered_heads[atomic_add(tmp_output, 1)] = this_head; } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_n_heads_filter_kernel, - num_blocks_for_heads_filter, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_n_heads_filter_kernel, + num_blocks_for_heads_filter, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_heads_filter_kernel", num_groups, start); size_t h_filtered_heads_size = 0; @@ -387,12 +392,12 @@ hipError_t search_n_impl(void* temporary_storage, } } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_n_discard_heads_kernel, - num_blocks_for_discard_heads, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_n_discard_heads_kernel, + num_blocks_for_discard_heads, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_discard_heads_kernel ", h_filtered_heads_size, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference.hpp index b8ebb7b961a..91d18c5db80 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference.hpp @@ -77,17 +77,17 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, using larger_type = std::conditional_t<(sizeof(value_type) >= sizeof(output_type)), value_type, output_type>; - using config = wrapped_adjacent_difference_config; + using Selector = adjacent_difference_config_selector; detail::target_arch target_arch; - hipError_t result = detail::host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } + ROCPRIM_RETURN_ON_ERROR(detail::host_target_arch(stream, target_arch)); + + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); - const detail::adjacent_difference_config_params params - = detail::dispatch_target_arch(target_arch); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -183,12 +183,12 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, starting_block); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - adjacent_difference_kernel, - current_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + adjacent_difference_kernel, + current_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("adjacent_difference_kernel", current_size, diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp index e68391dd03f..9e06b2631a5 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp @@ -43,63 +43,34 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// Specialization for user provided configuration -template -struct wrapped_adjacent_difference_config +template +struct adjacent_difference_config_selector { - static_assert( - std::is_same::value, - "Config must be a specialization of struct template adjacent_difference_config"); + using targets = std:: + conditional_t; + using param_type = adjacent_difference_config_params; - template - struct architecture_config - { - static constexpr adjacent_difference_config_params params = AdjacentDifferenceConfig{}; - }; -}; - -// Specialization for selecting the default configuration for in place -template -struct wrapped_adjacent_difference_config -{ - template - struct architecture_config - { - static constexpr adjacent_difference_config_params params - = default_adjacent_difference_inplace_config(Arch), Value>{}; - }; -}; + param_type params; -// Specialization for selecting the default configuration for out of place -template -struct wrapped_adjacent_difference_config -{ - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr adjacent_difference_config_params params - = default_adjacent_difference_config(Arch), Value>{}; - }; + // Different configs if it is inplace. + if constexpr(InPlace) + { + return adjacent_difference_inplace_config_picker(); + } + else + { + return adjacent_difference_config_picker(); + } + } + + template + constexpr adjacent_difference_config_selector(Target) : params(picker_helper()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr adjacent_difference_config_params - wrapped_adjacent_difference_config:: - architecture_config::params; -template -template -constexpr adjacent_difference_config_params - wrapped_adjacent_difference_config::architecture_config< - Arch>::params; -template -template -constexpr adjacent_difference_config_params - wrapped_adjacent_difference_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find.hpp index 65f74291359..539696ec499 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find.hpp @@ -69,8 +69,7 @@ hipError_t adjacent_find_impl(void* const temporary_storage, // Use dynamic tile id using ordered_tile_id_type = detail::ordered_block_id; - // Kernel launch config - using config = wrapped_adjacent_find_config; + using Selector = adjacent_find_config_selector; // Transform input auto wrapped_equal_op = [op, size](const wrapped_input_type& a) -> index_type @@ -138,7 +137,14 @@ hipError_t adjacent_find_impl(void* const temporary_storage, target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const adjacent_find_config_params params = dispatch_target_arch(target_arch); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); + const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -155,7 +161,8 @@ hipError_t adjacent_find_impl(void* const temporary_storage, ordered_tile_id); }; - auto adjacent_find_block_reduce_kernel = make_launch_plan(target_arch, kernel); + auto adjacent_find_block_reduce_kernel + = make_launch_plan(current_target, kernel); // Get grid size for maximum occupancy, as we may not be able to schedule all the blocks // at the same time diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find_config.hpp index 5a2990849c1..6e31a488448 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find_config.hpp @@ -34,67 +34,33 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_adjacent_find_config +template +struct adjacent_find_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template adjacent_find_config"); + using targets = adjacent_find_targets; + using param_type = adjacent_find_config_params; - template - struct architecture_config - { - static constexpr adjacent_find_config_params params = Config{}; - }; -}; + param_type params; -// Generic for default config: instantiate base config. -template -struct wrapped_adjacent_find_impl -{ - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr adjacent_find_config_params params = - typename default_adjacent_find_config_base::type{}; - }; + // Different configs if it is inplace. + if constexpr(rocprim::is_arithmetic::value) + { + return adjacent_find_config_picker(); + } + else + { + return adjacent_find_config_params_base(); + } + } + + template + constexpr adjacent_find_config_selector(Target) : params(picker_helper()) + {} }; -// Specialization for default config if types are arithmetic or half/bfloat16-precision -// floating point types: instantiate the tuned config. -template -struct wrapped_adjacent_find_impl::value>> -{ - template - struct architecture_config - { - static constexpr adjacent_find_config_params params - = default_adjacent_find_config(Arch), Type>(); - }; -}; - -// Specialization for default config. -template -struct wrapped_adjacent_find_config : wrapped_adjacent_find_impl -{}; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD -template -template -constexpr adjacent_find_config_params - wrapped_adjacent_find_config::architecture_config::params; - -template -template -constexpr adjacent_find_config_params - wrapped_adjacent_find_impl::architecture_config::params; - -template -template -constexpr adjacent_find_config_params wrapped_adjacent_find_impl< - Type, - std::enable_if_t::value>>::architecture_config::params; -#endif // DOXYGEN_DOCUMENTATION_BUILD - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp index e4dd5f2218e..3a00502eef8 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp @@ -39,26 +39,24 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class Config, - class HaystackIterator, - class NeedlesIterator, - class OutputIterator, - class SearchFunction, - class CompareFunction -> -inline -hipError_t binary_search(void * temporary_storage, - size_t& storage_size, - HaystackIterator haystack, - NeedlesIterator needles, - OutputIterator output, - size_t haystack_size, - size_t needles_size, - SearchFunction search_op, - CompareFunction compare_op, - hipStream_t stream, - bool debug_synchronous) +template +inline hipError_t binary_search(void* temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + SearchFunction search_op, + CompareFunction compare_op, + hipStream_t stream, + bool debug_synchronous) { using value_type = typename std::iterator_traits::value_type; @@ -70,7 +68,9 @@ hipError_t binary_search(void * temporary_storage, return hipSuccess; } - return detail::transform_impl( + constexpr bool is_pointer + = false; // We do not use the optimization for transform when input is a pointer. + return detail::transform_impl( needles, output, needles_size, @@ -209,27 +209,21 @@ hipError_t lower_bound(void * temporary_storage, hipStream_t stream = 0, bool debug_synchronous = false) { - static_assert(detail::is_default_or_has_tag::value, - "Config must be a specialization of struct template lower_bound_config"); - using value_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config - = std::conditional_t::value, - detail::default_config_for_lower_bound, - Config>; + using selector = detail::lower_bound_config_selector; - return detail::binary_search(temporary_storage, - storage_size, - haystack, - needles, - output, - haystack_size, - needles_size, - detail::lower_bound_search_op(), - compare_op, - stream, - debug_synchronous); + return detail::binary_search(temporary_storage, + storage_size, + haystack, + needles, + output, + haystack_size, + needles_size, + detail::lower_bound_search_op(), + compare_op, + stream, + debug_synchronous); } /// \brief Parallel primitive that uses binary search for computing an upper bound on a given ordered @@ -346,26 +340,21 @@ hipError_t upper_bound(void * temporary_storage, hipStream_t stream = 0, bool debug_synchronous = false) { - static_assert(detail::is_default_or_has_tag::value, - "Config must be a specialization of struct template upper_bound_config"); using value_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config - = std::conditional_t::value, - detail::default_config_for_upper_bound, - Config>; + using selector = detail::upper_bound_config_selector; - return detail::binary_search(temporary_storage, - storage_size, - haystack, - needles, - output, - haystack_size, - needles_size, - detail::upper_bound_search_op(), - compare_op, - stream, - debug_synchronous); + return detail::binary_search(temporary_storage, + storage_size, + haystack, + needles, + output, + haystack_size, + needles_size, + detail::upper_bound_search_op(), + compare_op, + stream, + debug_synchronous); } /// \brief Parallel primitive for performing a binary search (on a sorted range) of a given input. @@ -477,26 +466,21 @@ hipError_t binary_search(void * temporary_storage, hipStream_t stream = 0, bool debug_synchronous = false) { - static_assert(detail::is_default_or_has_tag::value, - "Config must be a specialization of struct template binary_search_config"); using value_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config - = std::conditional_t::value, - detail::default_config_for_binary_search, - Config>; + using selector = detail::binary_search_config_selector; - return detail::binary_search(temporary_storage, - storage_size, - haystack, - needles, - output, - haystack_size, - needles_size, - detail::binary_search_op(), - compare_op, - stream, - debug_synchronous); + return detail::binary_search(temporary_storage, + storage_size, + haystack, + needles, + output, + haystack_size, + needles_size, + detail::binary_search_op(), + compare_op, + stream, + debug_synchronous); } END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search_config.hpp index 3580830b88d..79c038ae52e 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search_config.hpp @@ -39,67 +39,46 @@ namespace detail { template -struct default_config_for_binary_search -{}; - -template -struct default_config_for_upper_bound -{}; +struct binary_search_config_selector +{ + using targets = binary_search_targets; + using param_type = transform_config_params; -template -struct default_config_for_lower_bound -{}; + param_type params; -template -struct wrapped_transform_config, Unused, IsPointer> -{ - template - struct architecture_config - { - static constexpr transform_config_params params - = default_binary_search_config(Arch), Value, Output>{}; - }; + template + constexpr binary_search_config_selector(Target) + : params(binary_search_config_picker()) + {} }; -template -struct wrapped_transform_config, Unused, IsPointer> +template +struct lower_bound_config_selector { - template - struct architecture_config - { - static constexpr transform_config_params params - = default_upper_bound_config(Arch), Value, Output>{}; - }; + using targets = lower_bound_targets; + using param_type = transform_config_params; + + param_type params; + + template + constexpr lower_bound_config_selector(Target) + : params(lower_bound_config_picker()) + {} }; -template -struct wrapped_transform_config, Unused, IsPointer> +template +struct upper_bound_config_selector { - template - struct architecture_config - { - static constexpr transform_config_params params - = default_lower_bound_config(Arch), Value, Output>{}; - }; -}; + using targets = upper_bound_targets; + using param_type = transform_config_params; + + param_type params; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr transform_config_params - wrapped_transform_config, Unused, IsPointer>:: - architecture_config::params; -template -template -constexpr transform_config_params - wrapped_transform_config, Unused, IsPointer>:: - architecture_config::params; -template -template -constexpr transform_config_params - wrapped_transform_config, Unused, IsPointer>:: - architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr upper_bound_config_selector(Target) + : params(upper_bound_config_picker()) + {} +}; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of.hpp index 6db6fa91388..fd64363a784 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of.hpp @@ -183,17 +183,19 @@ hipError_t find_first_of_impl(void* temporary_storage, bool debug_synchronous) { using type = typename std::iterator_traits::value_type; - using config = wrapped_find_first_of_config; + using Selector = find_first_of_config_selector; using find_first_of_kernels = find_first_of_impl_kernels; target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const find_first_of_config_params params = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -207,7 +209,7 @@ hipError_t find_first_of_impl(void* temporary_storage, ordered_bid_type::id_type* ordered_bid_storage = nullptr; // Calculate required temporary storage - result = temp_storage::partition( + hipError_t result = temp_storage::partition( temporary_storage, storage_size, temp_storage::make_linear_partition( @@ -246,7 +248,8 @@ hipError_t find_first_of_impl(void* temporary_storage, compare_function); }; - auto find_first_of_configured_kernel = make_launch_plan(target_arch, kernel); + auto find_first_of_configured_kernel + = make_launch_plan(current_target, kernel); const size_t shared_memory_size = 0; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of_config.hpp index 26c0a30d7a4..610389f2118 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of_config.hpp @@ -33,40 +33,19 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_find_first_of_config +template +struct find_first_of_config_selector { - template - struct architecture_config - { - static constexpr find_first_of_config_params params = Config{}; - }; -}; + using targets = find_first_of_targets; + using param_type = find_first_of_config_params; -// specialized for rocprim::default_config, which instantiates the default_find_first_of_config -template -struct wrapped_find_first_of_config -{ - template - struct architecture_config - { - static constexpr find_first_of_config_params params - = default_find_first_of_config(Arch), Type>(); - }; -}; + param_type params; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr find_first_of_config_params - wrapped_find_first_of_config::architecture_config::params; - -template -template -constexpr find_first_of_config_params - wrapped_find_first_of_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr find_first_of_config_selector(Target) + : params(find_first_of_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp index f102b51b0e2..b4ac6e6b4ed 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp @@ -124,63 +124,6 @@ struct HistogramSharedOp } }; -template -struct histogram_launch_plan -{ - using kernel_type = void (*)(Kernel); - - kernel_type kernel; - Kernel device_callback; - - unsigned int shared_impl_histograms = 0; - unsigned int max_grid_size = 0; - - void launch(dim3 grid, dim3 block, size_t shmem, hipStream_t stream) const - { - hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), grid, block, shmem, stream, device_callback); - } -}; - -template - class LaunchSelector> -auto make_histogram_launch_plan(rocprim::detail::target_arch arch, Kernel kernel) - -> histogram_launch_plan -{ - histogram_launch_plan plan{nullptr, std::move(kernel), 0u, 0u}; - - bool found = false; - - for_each_arch( - [&](auto arch_tag) - { - constexpr auto Arch = decltype(arch_tag)::value; - if(Arch != arch || found) - return; - - plan.kernel = trampoline_kernel; - - constexpr auto params = Config::template architecture_config::params; - - plan.shared_impl_histograms = params.shared_impl_histograms; - plan.max_grid_size = params.max_grid_size; - - found = true; - }); - if(!found) - { - constexpr auto Arch = rocprim::detail::target_arch::unknown; - - plan.kernel = trampoline_kernel; - - constexpr auto params = Config::template architecture_config::params; - plan.shared_impl_histograms = params.shared_impl_histograms; - plan.max_grid_size = params.max_grid_size; - } - return plan; -} - template::value_type; - - using config = wrapped_histogram_config; + using selector = histogram_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const histogram_config_params params = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.histogram_config.block_size; const unsigned int items_per_thread = params.histogram_config.items_per_thread; const unsigned int shared_impl_max_bins = params.shared_impl_max_bins; @@ -333,14 +275,13 @@ inline hipError_t histogram_impl(void* temporary_storage, init_histogram(hist, bin_counts); }; - ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, - init_histogram_kernel, - ::rocprim::detail::ceiling_div(max_bins, block_size), - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( + current_target, + init_histogram_kernel, + ::rocprim::detail::ceiling_div(max_bins, block_size), + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_histogram_kernel", max_bins, start); @@ -367,9 +308,9 @@ inline hipError_t histogram_impl(void* temporary_storage, 0, 0}; - auto plan = make_histogram_launch_plan( - target_arch, - op); + auto plan + = make_launch_plan(current_target, + op); const size_t block_histogram_bytes = total_shared_bins * sizeof(unsigned int); @@ -379,7 +320,7 @@ inline hipError_t histogram_impl(void* temporary_storage, // memory usage unsigned int chosen_shared_histograms = 0; int max_blocks_per_mp = 0; - for(unsigned int n = plan.shared_impl_histograms; n >= 1; n--) + for(unsigned int n = params.shared_impl_histograms; n >= 1; n--) { int blocks_per_mp; ROCPRIM_RETURN_ON_ERROR(hipOccupancyMaxActiveBlocksPerMultiprocessor( @@ -466,15 +407,14 @@ inline hipError_t histogram_impl(void* temporary_storage, block_id_count); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( - target_arch, - histogram_private_global_kernel, - dim3(global_histogram_grid_size), - dim3(params.histogram_global_config.block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan( + current_target, + histogram_private_global_kernel, + dim3(global_histogram_grid_size), + dim3(params.histogram_global_config.block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("histogram_private_global_kernel", blocks_x * block_size * rows, @@ -502,14 +442,13 @@ inline hipError_t histogram_impl(void* temporary_storage, }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(target_arch, - histogram_global_kernel, - dim3(blocks_x, rows), - dim3(block_size, 1), - 0, - stream)); + execute_launch_plan( + current_target, + histogram_global_kernel, + dim3(blocks_x, rows), + dim3(block_size, 1), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("histogram_global_kernel", blocks_x * block_size * rows, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp index b631e2a4b5b..8dd28c97f52 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp @@ -32,49 +32,33 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_histogram_config +template +struct histogram_config_static_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template histogram_config"); - - template - struct architecture_config - { - static constexpr histogram_config_params params = HistogramConfig{}; - }; + static constexpr auto block_size + = target_config::params.histogram_config.block_size; }; -template -struct wrapped_histogram_config +template +struct histogram_global_config_static_selector { - template - struct architecture_config - { - static constexpr histogram_config_params params - = default_histogram_config(Arch), - Sample, - Channels, - ActiveChannels>{}; - }; + static constexpr auto block_size + = target_config::params.histogram_global_config.block_size; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr histogram_config_params - wrapped_histogram_config:: - architecture_config::params; +template +struct histogram_config_selector +{ + using targets = histogram_targets; + using param_type = histogram_config_params; + + param_type params; -template -template -constexpr histogram_config_params - wrapped_histogram_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr histogram_config_selector(Target) + : params(histogram_config_picker()) + {} +}; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp index 5e8af2a3a62..6586e1fa197 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp @@ -40,47 +40,45 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// Specialization for user provided configuration -template -struct wrapped_batch_memcpy_config +template +struct non_blev_batch_memcpy_config_static_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template batch_memcpy_config"); - - template - struct architecture_config - { - static constexpr batch_memcpy_config_params params = BatchMemcpyConfig{}; - }; + static constexpr auto block_size = target_config::params + .non_blev_batch_memcpy_kernel_config.block_size; }; -// Specialization for selecting the default configuration for out of place -template -struct wrapped_batch_memcpy_config +template +struct blev_batch_memcpy_config_static_selector { - template - struct architecture_config - { - static constexpr batch_memcpy_config_params params - = IsMemCpy ? (batch_memcpy_config_params) - default_batch_memcpy_config(Arch), Value>{} - : (batch_memcpy_config_params) - default_batch_copy_config(Arch), Value>{}; - }; + static constexpr auto block_size = target_config::params + .blev_batch_memcpy_kernel_config.block_size; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr batch_memcpy_config_params - wrapped_batch_memcpy_config::architecture_config< - Arch>::params; template -template -constexpr batch_memcpy_config_params - wrapped_batch_memcpy_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS +struct batch_memcpy_config_selector +{ + using targets = std::conditional_t; + using param_type = batch_memcpy_config_params; + + param_type params; + + template + constexpr param_type picker_helper() + { + if constexpr(IsMemCpy) + { + return batch_memcpy_config_picker(); + } + else + { + return batch_copy_config_picker(); + } + } + + template + constexpr batch_memcpy_config_selector(Target) : params(picker_helper()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp index fba3fd5c389..4b41bf68fb3 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp @@ -42,32 +42,27 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -inline size_t get_merge_vsmem_size_per_block(detail::target_arch arch) +template +inline size_t get_merge_vsmem_size_per_block(detail::target t) { - std::optional vsmem_per_block; - for_each_arch( - [&](auto arch_tag) + using targets = typename Selector::targets; + + size_t vsmem_per_block = 0; + + targets::for_each( + [&](auto candidate) { - constexpr target_arch Arch = decltype(arch_tag)::value; - if(Arch != arch || vsmem_per_block) - return; - using ArchConfig = typename Config::template architecture_config; - using merge_kernel_impl_t = merge_kernel_impl_; - using merge_vsmem_helper_t = detail::vsmem_helper_impl; - - vsmem_per_block = merge_vsmem_helper_t::vsmem_per_block; + if(target{candidate} == most_common_config(t)) + { + using ArchConfig = target_config; + using merge_kernel_impl_t = merge_kernel_impl_; + using merge_vsmem_helper_t = detail::vsmem_helper_impl; + + vsmem_per_block = merge_vsmem_helper_t::vsmem_per_block; + } }); - if(!vsmem_per_block) - { - using ArchConfig = typename Config::template architecture_config; - using merge_kernel_impl_t = merge_kernel_impl_; - using merge_vsmem_helper_t = detail::vsmem_helper_impl; - - vsmem_per_block = merge_vsmem_helper_t::vsmem_per_block; - } - return vsmem_per_block.value(); + return vsmem_per_block; } template::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_merge_config; + using selector = merge_config_selector; detail::target_arch target_arch; - hipError_t result = detail::host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_config_params params = detail::dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int half_block = block_size / 2; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -115,7 +110,7 @@ inline hipError_t merge_impl(void* temporary_storage, = ((input1_size + input2_size) + items_per_block - 1) / items_per_block; size_t virtual_shared_memory_size - = get_merge_vsmem_size_per_block(target_arch) + = get_merge_vsmem_size_per_block(current_target) * number_of_blocks; unsigned int* index = nullptr; @@ -172,12 +167,12 @@ inline hipError_t merge_impl(void* temporary_storage, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - partition_kernel, - partition_blocks, - half_block, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + partition_kernel, + partition_blocks, + half_block, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("partition_kernel", input1_size, start); @@ -212,12 +207,12 @@ inline hipError_t merge_impl(void* temporary_storage, storage); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - merge_kernel, - dim3(number_of_blocks), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + merge_kernel, + dim3(number_of_blocks), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("merge_kernel", input1_size, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge_config.hpp index 5bf72abe710..a585d1ae747 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge_config.hpp @@ -39,40 +39,18 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_merge_config +template +struct merge_config_selector { - template - struct architecture_config - { - static constexpr merge_config_params params = Config(); - }; -}; + using targets = merge_targets; + using param_type = merge_config_params; -// specialized for rocprim::default_config, which instantiates the default_ALGO_config -template -struct wrapped_merge_config -{ - template - struct architecture_config - { - static constexpr merge_config_params params - = default_merge_config(Arch), KeyType, ValueType>{}; - }; -}; + param_type params; -#ifndef DOXYGEN_DOCUMENTATION_BUILD -template -template -constexpr merge_config_params - wrapped_merge_config::architecture_config::params; - -template -template -constexpr merge_config_params - wrapped_merge_config::architecture_config::params; -#endif // DOXYGEN_DOCUMENTATION_BUILD + template + constexpr merge_config_selector(Target) : params(merge_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp index 00b29896068..73759a712f2 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp @@ -72,17 +72,16 @@ inline hipError_t merge_sort_block_merge_impl( using value_type = typename std::iterator_traits::value_type; constexpr bool with_values = !std::is_same::value; - using config = wrapped_merge_sort_block_merge_config; + using selector = merge_sort_block_merge_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_sort_block_merge_config_params params - = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); const unsigned int merge_oddeven_block_size = params.merge_oddeven_config.block_size; const unsigned int merge_oddeven_items_per_thread = params.merge_oddeven_config.items_per_thread; @@ -211,10 +210,10 @@ inline hipError_t merge_sort_block_merge_impl( // Note: shared memory is not used in this kernel so there is no need to pass vsmem ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + current_target, device_block_merge_mergepath_partition_kernel, dim3(merge_partition_number_of_blocks), dim3(merge_partition_block_size), @@ -260,10 +259,8 @@ inline hipError_t merge_sort_block_merge_impl( storage); }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + current_target, device_block_merge_mergepath_kernel, calculate_grid_dim(merge_mergepath_number_of_blocks, merge_mergepath_block_size), @@ -299,10 +296,8 @@ inline hipError_t merge_sort_block_merge_impl( compare_function); }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + current_target, device_block_merge_oddeven_kernel, dim3(merge_oddeven_number_of_blocks), dim3(merge_oddeven_block_size), @@ -383,17 +378,16 @@ inline hipError_t merge_sort_block_merge( using value_type = typename std::iterator_traits::value_type; constexpr bool with_values = !std::is_same::value; - using config = wrapped_merge_sort_block_merge_config; + using selector = merge_sort_block_merge_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_sort_block_merge_config_params params - = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); const unsigned int merge_mergepath_block_size = params.merge_mergepath_config.block_size; const unsigned int merge_mergepath_items_per_thread = params.merge_mergepath_config.items_per_thread; @@ -476,17 +470,16 @@ inline hipError_t merge_sort_block_sort(KeysInputIterator keys_input, using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_merge_sort_block_sort_config; + using selector = merge_sort_block_sort_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_sort_block_sort_config_params params - = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); sort_items_per_block = params.kernel_config.block_size * params.kernel_config.items_per_thread; const unsigned int sort_number_of_blocks = ceiling_div(size, sort_items_per_block); @@ -546,8 +539,8 @@ inline hipError_t merge_sort_block_sort(KeysInputIterator keys_input, storage); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( - target_arch, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( + current_target, block_sort_kernel, calculate_grid_dim(sort_number_of_blocks, params.kernel_config.block_size), params.kernel_config.block_size, @@ -567,12 +560,17 @@ void TAssertEqualGreater() static_assert(A >= B, "A not greater or equal to B"); }; -template +template ROCPRIM_KERNEL void device_merge_sort_compile_time_verifier_arch() { - using BSArchConfig = typename BlockSortConfig::template architecture_config; - using BMArchConfig = typename BlockMergeConfig::template architecture_config; + using BSArchConfig = target_config; + using BMArchConfig = target_config; static constexpr auto bs_params = BSArchConfig::params; static constexpr auto bm_params = BMArchConfig::params; @@ -602,36 +600,51 @@ void device_merge_sort_compile_time_verifier_arch() "merge_mergepath_items_per_block"); } -template +template inline void device_merge_sort_compile_time_verifier() noexcept { - static const bool once = [] - { - for_each_arch( - [](auto arch_tag) - { - constexpr auto A = decltype(arch_tag)::value; - (void)&device_merge_sort_compile_time_verifier_arch; - }); - (void)&device_merge_sort_compile_time_verifier_arch; - return true; - }(); - (void)once; + // BSTargets and BMTargets can be different so we do not know at compile time + // the combination of configs that will be chosen. + using BSTargets = typename BSSelector::targets; + using BMTargets = typename BMSelector::targets; + + BSTargets::for_each( + [&](auto t) + { + constexpr target ct = most_common_config(target{t}); + (void)device_merge_sort_compile_time_verifier_arch; + }); + + BMTargets::for_each( + [&](auto t) + { + constexpr target ct = most_common_config(target{t}); + (void)device_merge_sort_compile_time_verifier_arch; + }); } -template -inline size_t merge_sort_vsmem_size_for_arch(size_t size) + class BlockSortSelector, + class BlockMergeSelector, + class Key, + class Value> +inline size_t merge_sort_vsmem_size_for_target(size_t size) { - using BSArchConfig = typename BlockSortConfig::template architecture_config; - using BMArchConfig = typename BlockMergeConfig::template architecture_config; + using BSArchConfig = target_config; + using BMArchConfig = target_config; static constexpr auto bs_params = BSArchConfig::params; static constexpr auto bm_params = BMArchConfig::params; @@ -679,35 +692,44 @@ inline size_t merge_sort_vsmem_size_for_arch(size_t size) return virtual_shared_memory_size; } -template -inline size_t get_merge_sort_vsmem_size(detail::target_arch arch, size_t size) noexcept +template +inline size_t get_merge_sort_vsmem_size(detail::target t, size_t size) noexcept { - std::optional out; + using BlockSortTarget = typename BlockSortSelector::targets; + using BlockMergeTarget = typename BlockMergeSelector::targets; - for_each_arch( - [&](auto arch_tag) + size_t vsmem_per_block = 0; + + BlockSortTarget::for_each( + [&](auto BScandidate) { - if(out) - return; - constexpr auto Arch = decltype(arch_tag)::value; - if(Arch != arch) - return; - - out = merge_sort_vsmem_size_for_arch(size); + if(target{BScandidate} == most_common_config(t)) + { + BlockMergeTarget::for_each( + [&](auto BMcandidate) + { + if(target{BMcandidate} == most_common_config(t)) + { + vsmem_per_block + = merge_sort_vsmem_size_for_target(size); + } + }); + } }); - if(!out) - { - out = merge_sort_vsmem_size_for_arch(size); - } - return *out; + + return vsmem_per_block; } template::type; using block_merge_config = typename std:: conditional::type; - using wrapped_bs_config - = wrapped_merge_sort_block_sort_config; - using wrapped_bm_config - = wrapped_merge_sort_block_merge_config; + using selector_bm = merge_sort_block_merge_config_selector; + using selector_bs = merge_sort_block_sort_config_selector; // Some helpful checks during compile-time - device_merge_sort_compile_time_verifier(); + device_merge_sort_compile_time_verifier(); unsigned int sort_items_per_block = 1; // We will get this later from the block_sort algorithm detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_sort_block_merge_config_params params - = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(block_merge_config{}, current_target); const bool use_mergepath = size > params.merge_oddeven_config.size_limit; const unsigned int merge_mergepath_items_per_block = params.merge_mergepath_config.block_size * params.merge_mergepath_config.items_per_thread; @@ -767,10 +790,12 @@ inline hipError_t merge_sort_impl( // Virtual shared memory part void* vsmem = nullptr; - size_t virtual_shared_memory_size - = get_merge_sort_vsmem_size( - target_arch, - size); + size_t virtual_shared_memory_size = get_merge_sort_vsmem_size(current_target, size); // temporary storage needed for both block merge and block sort size_t* d_merge_partitions = nullptr; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp index 4883c4a08a8..4bf89d0faea 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp @@ -86,77 +86,55 @@ namespace detail { // Sub algorithm block_merge: - -template -struct wrapped_merge_sort_block_merge_config +template +struct merge_oddeven_config_static_selector { - template - struct architecture_config - { - static constexpr merge_sort_block_merge_config_params params = MergeSortBlockMergeConfig(); - }; + static constexpr auto block_size + = target_config::params.merge_oddeven_config.block_size; }; -template -struct wrapped_merge_sort_block_merge_config +template +struct merge_mergepath_partition_config_static_selector { - template - struct architecture_config - { - static constexpr merge_sort_block_merge_config_params params - = default_merge_sort_block_merge_config(Arch), Key, Value>(); - }; + static constexpr auto block_size = target_config::params + .merge_mergepath_partition_config.block_size; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr merge_sort_block_merge_config_params - wrapped_merge_sort_block_merge_config:: - architecture_config::params; - -template -template -constexpr merge_sort_block_merge_config_params - wrapped_merge_sort_block_merge_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - -// Sub-algorithm block_sort: -template -struct wrapped_merge_sort_block_sort_config +template +struct merge_mergepath_config_static_selector { - template - struct architecture_config - { - static constexpr merge_sort_block_sort_config_params params = MergeSortBlockSortConfig(); - }; + static constexpr auto block_size + = target_config::params.merge_mergepath_config.block_size; }; -template -struct wrapped_merge_sort_block_sort_config +template +struct merge_sort_block_merge_config_selector { - template - struct architecture_config - { - static constexpr merge_sort_block_sort_config_params params - = default_merge_sort_block_sort_config(Arch), Key, Value>(); - }; + using targets = merge_sort_block_merge_targets; + using param_type = merge_sort_block_merge_config_params; + + param_type params; + + template + constexpr merge_sort_block_merge_config_selector(Target) + : params(merge_sort_block_merge_config_picker()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr merge_sort_block_sort_config_params - wrapped_merge_sort_block_sort_config::architecture_config< - Arch>::params; - -template -template -constexpr merge_sort_block_sort_config_params - wrapped_merge_sort_block_sort_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS +// Sub-algorithm block_sort: +template +struct merge_sort_block_sort_config_selector +{ + using targets = merge_sort_block_sort_targets; + using param_type = merge_sort_block_sort_config_params; + + param_type params; + + template + constexpr merge_sort_block_sort_config_selector(Target) + : params(merge_sort_block_sort_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_nth_element.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_nth_element.hpp index d0bdc770756..25041a9fc6c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_nth_element.hpp @@ -58,7 +58,7 @@ hipError_t = nullptr) { using key_type = typename std::iterator_traits::value_type; - using config = wrapped_nth_element_config; + using selector = nth_element_config_selector; bool use_atomic_block_id; ROCPRIM_RETURN_ON_ERROR(check_if_using_atomic_block_id(stream, use_atomic_block_id)); @@ -70,17 +70,15 @@ hipError_t [&](auto use_atomic_block_id) { target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const nth_element_config_params params - = dispatch_target_arch(target_arch); - - constexpr unsigned int num_partitions = 3; - const unsigned int num_buckets = params.number_of_buckets; - const unsigned int num_splitters = num_buckets - 1; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); + constexpr unsigned int num_partitions = 3; + const unsigned int num_buckets = params.number_of_buckets; + const unsigned int num_splitters = num_buckets - 1; const unsigned int stop_recursion_size = params.stop_recursion_size; const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; const unsigned int num_threads_per_block = params.kernel_config.block_size; @@ -163,24 +161,24 @@ hipError_t auto ordered_bid = ordered_bid_type::create(ordered_bid_storage); - return nth_element_keys_impl(target_arch, - keys, - keys_buffer, - tree, - nth, - size, - buckets, - equality_buckets, - lookback_states, - num_buckets, - stop_recursion_size, - num_threads_per_block, - num_items_per_threads, - nth_element_data, - compare_function, - stream, - debug_synchronous, - ordered_bid); + return nth_element_keys_impl(current_target, + keys, + keys_buffer, + tree, + nth, + size, + buckets, + equality_buckets, + lookback_states, + num_buckets, + stop_recursion_size, + num_threads_per_block, + num_items_per_threads, + nth_element_data, + compare_function, + stream, + debug_synchronous, + ordered_bid); }, use_atomic_block_id_variant)); return hipSuccess; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_nth_element_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_nth_element_config.hpp index 3d18bd7e146..c8bb024b7e1 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_nth_element_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_nth_element_config.hpp @@ -35,41 +35,24 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_nth_element_config +template +struct nth_element_config_selector { - template - struct architecture_config - { - static constexpr nth_element_config_params params = NthElementConfig{}; - }; + // Targets can not be fully empty. + using targets + = comp_targets>; + using param_type = nth_element_config_params; + + param_type params; + + template + constexpr nth_element_config_selector(Target) + : params(param_type{ + 64, 64, block_radix_rank_algorithm::match, kernel_config_params{512, 8} + }) + {} }; -// specialized for rocprim::default_config, which instantiates the default_nth_element_config -template -struct wrapped_nth_element_config -{ - template - struct architecture_config - { - static constexpr nth_element_config_params params - = {64, 64, block_radix_rank_algorithm::match, kernel_config<512, 8>()}; - }; -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr nth_element_config_params - wrapped_nth_element_config::architecture_config::params; - -template -template -constexpr nth_element_config_params - wrapped_nth_element_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp index 0f9db2d5ad9..55e58c1a0a3 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp @@ -52,6 +52,7 @@ namespace detail { template -inline size_t get_partition_vsmem_size_per_block(detail::target_arch arch) +inline size_t get_partition_vsmem_size_per_block(detail::target t) { + using targets = typename Selector::targets; using offset_type = typename OffsetLookbackScanState::value_type; - std::optional vsmem_per_block; - for_each_arch( - [&](auto arch_tag) + + size_t vsmem_per_block = 0; + + targets::for_each( + [&](auto candidate) { - constexpr target_arch Arch = decltype(arch_tag)::value; - if(Arch != arch || vsmem_per_block) + if(target{candidate} == most_common_config(t)) { - return; + using ArchConfig = target_config; + using partition_kernel_impl_t = partition_kernel_impl_; + + using partition_vsmem_helper_t = detail::vsmem_helper_impl; + vsmem_per_block = partition_vsmem_helper_t::vsmem_per_block; } - - using ArchConfig = typename Config::template architecture_config; - using partition_kernel_impl_t = partition_kernel_impl_; - using partition_vsmem_helper_t = detail::vsmem_helper_impl; - - vsmem_per_block = partition_vsmem_helper_t::vsmem_per_block; }); - if(!vsmem_per_block) - { - using ArchConfig = typename Config::template architecture_config; - using partition_kernel_impl_t = partition_kernel_impl_; - using partition_vsmem_helper_t = detail::vsmem_helper_impl; - - vsmem_per_block = partition_vsmem_helper_t::vsmem_per_block; - } - return vsmem_per_block.value(); + + return vsmem_per_block; } template; using block_id_type = detail::block_id_wrapper; - using config = wrapped_partition_config; - constexpr bool write_only_selected = SubAlgo == partition_subalgo::select_flag || SubAlgo == partition_subalgo::select_predicate @@ -173,10 +158,19 @@ inline hipError_t partition_impl(void* temporary_storage, ? select_method::predicated_flag : (is_flag ? select_method::flag : select_method::predicate)); + using flag_type = + typename std::conditional::value_type, + bool>::type; + using selector = partition_config_selector; + detail::target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const partition_config_params params = dispatch_target_arch(target_arch); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const auto items_per_block = block_size * items_per_thread; @@ -208,20 +202,17 @@ inline hipError_t partition_impl(void* temporary_storage, // vsmem size void* vsmem = nullptr; size_t virtual_shared_memory_size = 0; - using flag_type = - typename std::conditional::value_type, - bool>::type; virtual_shared_memory_size - = get_partition_vsmem_size_per_block(target_arch); + block_id_type>(current_target); virtual_shared_memory_size *= number_of_blocks; // temporary storage partition @@ -346,12 +337,13 @@ inline hipError_t partition_impl(void* temporary_storage, storage, predicates...); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - partition_kernel, - dim3(current_number_of_blocks), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(current_target, + partition_kernel, + dim3(current_number_of_blocks), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("partition_kernel", size, start); std::swap(selected_count, prev_selected_count); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_partition_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_partition_config.hpp index 69c92f0d7ac..db923d87fcc 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_partition_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_partition_config.hpp @@ -42,245 +42,108 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_partition_config +template +constexpr auto algo_target_type() { - template - struct architecture_config + if constexpr(Algo == partition_subalgo::partition_two_way_predicate) { - static constexpr partition_config_params params = PartitionConfig{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::partition_two_way_flag) { - static constexpr partition_config_params params - = default_partition_two_way_predicate_config(Arch), - KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::partition_flag) { - static constexpr partition_config_params params - = default_partition_two_way_flag_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::partition_predicate) { - static constexpr partition_config_params params - = default_partition_flag_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::partition_three_way) { - static constexpr partition_config_params params - = default_partition_predicate_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_flag) { - static constexpr partition_config_params params - = default_partition_three_way_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_predicate) { - static constexpr partition_config_params params - = default_select_flag_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_predicated_flag) { - static constexpr partition_config_params params - = default_select_predicate_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_unique) { - static constexpr partition_config_params params - = default_select_predicated_flag_config(Arch), - KeyType, - ValueType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_unique_by_key) { - static constexpr partition_config_params params - = default_select_unique_config(Arch), KeyType>{}; - }; -}; + return type_identity{}; + } +} -template -struct wrapped_partition_config +template +struct partition_config_selector { - template - struct architecture_config - { - static constexpr partition_config_params params - = default_select_unique_by_key_config(Arch), - KeyType, - ValueType>{}; - }; -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config< - Arch>::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; + using targets = typename decltype(algo_target_type())::type; + using param_type = partition_config_params; -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; + param_type params; -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config:: - architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr param_type picker_helper() + { + if constexpr(SubAlgo == partition_subalgo::partition_two_way_predicate) + { + return partition_two_way_predicate_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::partition_two_way_flag) + { + return partition_two_way_flag_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::partition_flag) + { + return partition_flag_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::partition_predicate) + { + return partition_predicate_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::partition_three_way) + { + return partition_three_way_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_flag) + { + return select_flag_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_predicate) + { + return select_predicate_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_predicated_flag) + { + return select_predicated_flag_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_unique) + { + return select_unique_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_unique_by_key) + { + return select_unique_by_key_config_picker(); + } + } + + template + constexpr partition_config_selector(Target) : params(picker_helper()) + {} +}; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp index db3a92d33ad..37fee2bf475 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp @@ -36,6 +36,7 @@ #include "../types.hpp" #include "../type_traits.hpp" +#include "config_types.hpp" #include "detail/config/device_radix_sort_onesweep.hpp" #include "detail/device_radix_sort.hpp" #include "device_transform.hpp" @@ -97,16 +98,18 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, { using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_radix_sort_onesweep_config; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const radix_sort_onesweep_config_params params - = dispatch_target_arch(target_arch); + using Selector = radix_sort_onesweep_config_selector; + + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const radix_sort_onesweep_config_params params = get_config(Config{}, current_target); const unsigned int items_per_block = params.histogram.block_size * params.histogram.items_per_thread; @@ -147,15 +150,15 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, begin_bit, end_bit); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( - target_arch, - onesweep_histograms_kernel, - dim3(blocks), - dim3(params.histogram.block_size), - 0, - stream)); + + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan( + current_target, + onesweep_histograms_kernel, + dim3(blocks), + dim3(params.histogram.block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("compute_global_digit_histograms", size, start); // Scan each histogram separately to get the final offsets. @@ -170,15 +173,16 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, onesweep_scan_histograms( global_digit_offsets); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( - target_arch, - onesweep_scan_histograms_kernel, - dim3(digit_places), // One block for every digit place. - dim3(params.histogram.block_size), - 0, - stream)); + + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan( + current_target, + onesweep_scan_histograms_kernel, + dim3(digit_places), // One block for every digit place. + dim3(params.histogram.block_size), + 0, + stream)); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("scan_global_digit_histograms", bins, start); return hipSuccess; } @@ -214,12 +218,18 @@ hipError_t radix_sort_onesweep_iteration( { using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_radix_sort_onesweep_config; - detail::target_arch target_arch; + using Selector = radix_sort_onesweep_config_selector; + + target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const radix_sort_onesweep_config_params params - = dispatch_target_arch(target_arch); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const radix_sort_onesweep_config_params params = get_config(Config{}, current_target); const unsigned int items_per_block = params.sort.block_size * params.sort.items_per_thread; const unsigned int current_radix_bits @@ -278,7 +288,8 @@ hipError_t radix_sort_onesweep_iteration( { auto onesweep_iteration_kernel = [=](auto arch_config) { - static constexpr auto params = decltype(arch_config)::params; + static constexpr radix_sort_onesweep_config_params params + = decltype(arch_config)::params; onesweep_iteration( - target_arch, + return execute_launch_plan( + current_target, onesweep_iteration_kernel, dim3(blocks), dim3(params.sort.block_size), @@ -373,6 +384,8 @@ hipError_t radix_sort_onesweep_impl( using value_type = typename std::iterator_traits::value_type; using offset_type = offset_type_t; + using Selector = radix_sort_onesweep_config_selector; + bool use_atomic_block_id; ROCPRIM_RETURN_ON_ERROR(check_if_using_atomic_block_id(stream, use_atomic_block_id)); const auto use_atomic_block_id_variant @@ -383,13 +396,17 @@ hipError_t radix_sort_onesweep_impl( [&](auto use_atomic_block_id) { using ordered_bid_type = block_id_wrapper; - using config = wrapped_radix_sort_onesweep_config; - detail::target_arch target_arch; + target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + const radix_sort_onesweep_config_params params - = dispatch_target_arch(target_arch); + = get_config(Config{}, current_target); const unsigned int sort_items_per_block = params.sort.block_size * params.sort.items_per_thread; @@ -593,11 +610,11 @@ hipError_t constexpr bool use_default_small_block_sort = is_default_config || std::is_same::value; - using default_radix_sort_block_sort_config = - typename radix_sort_block_sort_config_base::type; + constexpr auto default_radix_sort_block_sort_config + = radix_sort_block_sort_config_params_base(); using default_block_sort_config - = kernel_config; + = kernel_config; using block_sort_config = typename std::conditional::type; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp index 52ac827660a..a8ce230d97b 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp @@ -23,6 +23,7 @@ #include "config_types.hpp" #include "detail/config/device_radix_sort_block_sort.hpp" +#include "detail/config/device_radix_sort_onesweep.hpp" #include "detail/device_config_helper.hpp" /// \addtogroup primitivesmodule_deviceconfigs @@ -65,76 +66,53 @@ struct radix_sort_config namespace detail { -// sub-algorithm onesweep: -template -struct wrapped_radix_sort_onesweep_config +template +struct radix_sort_onesweep_config_selector { - template - struct architecture_config - { - static constexpr radix_sort_onesweep_config_params params = RadixSortOnesweepConfig(); - }; + using targets = radix_sort_onesweep_targets; + using param_type = radix_sort_onesweep_config_params; + + param_type params; + + template + constexpr radix_sort_onesweep_config_selector(Target) + : params(radix_sort_onesweep_config_picker()) + {} }; -template -struct wrapped_radix_sort_onesweep_config +template +struct radix_sort_onesweep_histogram_config_static_selector { - template - struct architecture_config - { - static constexpr radix_sort_onesweep_config_params params - = default_radix_sort_onesweep_config(Arch), Key, Value>(); - }; + static constexpr auto block_size = target_config::params + .radix_sort_onesweep_config_params::histogram.block_size; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr radix_sort_onesweep_config_params - wrapped_radix_sort_onesweep_config::architecture_config< - Arch>::params; - -template -template -constexpr radix_sort_onesweep_config_params - wrapped_radix_sort_onesweep_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - -// Sub-algorithm block_sort: -template -struct wrapped_radix_sort_block_sort_config +template +struct radix_sort_onesweep_sort_config_static_selector { - template - struct architecture_config - { - static constexpr kernel_config_params params = RadixSortBlockSortConfig(); - }; + static constexpr auto block_size = target_config::params + .radix_sort_onesweep_config_params::sort.block_size; }; -template -struct wrapped_radix_sort_block_sort_config +template +struct radix_sort_block_sort_config_selector { - template - struct architecture_config - { - static constexpr kernel_config_params params - = default_radix_sort_block_sort_config(Arch), Key, Value>(); - }; + using targets = radix_sort_block_sort_targets; + using param_type = kernel_config_params; + + param_type params; + + template + constexpr radix_sort_block_sort_config_selector(Target) + : params(radix_sort_block_sort_config_picker()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr kernel_config_params - wrapped_radix_sort_block_sort_config::architecture_config< - Arch>::params; - -template -template -constexpr kernel_config_params wrapped_radix_sort_block_sort_config:: - architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS +template +struct radix_sort_block_sort_config_static_selector +{ + static constexpr auto block_size = target_config::params.block_size; +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp index 02c1a7043c5..c4a3da51d83 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp @@ -80,12 +80,12 @@ namespace detail result_type>(input, size, output, initial_value, reduce_op); \ }; \ \ - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, \ - block_reduce_kernel, \ - dim3(1), \ - dim3(block_size), \ - 0, \ - stream)); \ + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, \ + block_reduce_kernel, \ + dim3(1), \ + dim3(block_size), \ + 0, \ + stream)); \ ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_reduce_kernel", size, start); \ } \ while(0) @@ -109,12 +109,17 @@ inline hipError_t reduce_impl(void* temporary_storage, using input_type = typename std::iterator_traits::value_type; using result_type = ::rocprim::accumulator_t; - using config = wrapped_reduce_config; + using Selector = reduce_config_selector; - detail::target_arch target_arch; + target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const reduce_config_params params = dispatch_target_arch(target_arch); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -196,12 +201,12 @@ inline hipError_t reduce_impl(void* temporary_storage, initial_value, reduce_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - block_reduce_kernel, - dim3(current_blocks), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + block_reduce_kernel, + dim3(current_blocks), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_reduce_kernel", current_size, start); } diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp index 666470a92fc..4b09981c36c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp @@ -97,7 +97,8 @@ ROCPRIM_KERNEL ROCPRIM_LAUNCH_BOUNDS(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) void } template(target_arch); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); using scan_state_type = reduce_by_key::lookback_scan_state_t; @@ -268,12 +274,13 @@ hipError_t reduce_by_key_impl_wrapped_config(void* temporary i > 0 ? d_previous_accumulated : nullptr, ordered_bid); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - kernel, - dim3(number_of_blocks_launch), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(current_target, + kernel, + dim3(number_of_blocks_launch), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("reduce_by_key_kernel", current_size, start); @@ -310,21 +317,21 @@ hipError_t reduce_by_key_impl(void* temporary_storage, { using key_type = ::rocprim::detail::value_type_t; using accumulator_type = reduce_by_key::accumulator_type_t; + using Selector = reduce_by_key_config_selector; - using config = wrapped_reduce_by_key_config; - - return detail::reduce_by_key_impl_wrapped_config(temporary_storage, - storage_size, - keys_input, - values_input, - size, - unique_output, - aggregates_output, - unique_count_output, - reduce_op, - key_compare_op, - stream, - debug_synchronous); + return detail::reduce_by_key_impl_wrapped_config( + temporary_storage, + storage_size, + keys_input, + values_input, + size, + unique_output, + aggregates_output, + unique_count_output, + reduce_op, + key_compare_op, + stream, + debug_synchronous); } } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key_config.hpp index 9b028a64434..00b13a25601 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key_config.hpp @@ -33,84 +33,34 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// Generic for user-provided config: instantiate user-provided config. -template -struct wrapped_reduce_by_key_config +template +struct reduce_by_key_config_selector { - template - struct architecture_config - { - static constexpr reduce_by_key_config_params params = ReduceByKeyConfig{}; - }; -}; + using targets = reduce_by_key_targets; + using param_type = reduce_by_key_config_params; -// Generic for default config: instantiate base config. -template -struct wrapped_reduce_by_key_impl -{ - template - struct architecture_config - { - static constexpr reduce_by_key_config_params params = - typename default_reduce_by_key_config_base::type{}; - }; -}; + param_type params; -// Specialization for default config if types are not custom: instantiate the tuned config. -template -struct wrapped_reduce_by_key_impl< - KeyType, - AccumulatorType, - BinaryFunction, - std::enable_if_t::value && is_arithmetic::value - && is_binary_functional::value>> -{ - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr reduce_by_key_config_params params - = default_reduce_by_key_config(Arch), - KeyType, - AccumulatorType>{}; - }; + // Specialization for default config if types are not custom: instantiate the tuned config. + if constexpr(rocprim::is_arithmetic::value && rocprim::is_arithmetic::value + && rocprim::detail::is_binary_functional::value) + { + return reduce_by_key_config_picker(); + } + else + { + return reduce_by_key_config_params_base(); + } + } + + template + constexpr reduce_by_key_config_selector(Target) : params(picker_helper()) + {} }; -// Specialization for default config. -template -struct wrapped_reduce_by_key_config - : wrapped_reduce_by_key_impl -{}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr reduce_by_key_config_params - wrapped_reduce_by_key_config:: - architecture_config::params; - -template -template -constexpr reduce_by_key_config_params - wrapped_reduce_by_key_impl:: - architecture_config::params; - -template -template -constexpr reduce_by_key_config_params wrapped_reduce_by_key_impl< - KeyType, - AccumulatorType, - BinaryFunction, - std::enable_if_t::value && is_arithmetic::value - && is_binary_functional::value>>::architecture_config:: - params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_config.hpp index 45f0464536f..c41acc8690a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_config.hpp @@ -33,42 +33,19 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_reduce_config +template +struct reduce_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template reduce_config"); + using targets = reduce_targets; + using param_type = reduce_config_params; - template - struct architecture_config - { - static constexpr reduce_config_params params = ReduceConfig(); - }; -}; + param_type params; -template -struct wrapped_reduce_config -{ - template - struct architecture_config - { - static constexpr reduce_config_params params - = default_reduce_config(Arch), Value>(); - }; + template + constexpr reduce_config_selector(Target) : params(reduce_config_picker()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr reduce_config_params - wrapped_reduce_config::architecture_config::params; - -template -template -constexpr reduce_config_params - wrapped_reduce_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp index 6d6a4720878..7fd67e97220 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp @@ -76,23 +76,23 @@ hipError_t run_length_encode_impl(void* temporary_storage, { using key_type = ::rocprim::detail::value_type_t; using accumulator_type = reduce_by_key::accumulator_type_t; - - using config = wrapped_trivial_runs_config; + using Selector = run_length_encode_config_selector; return detail::reduce_by_key_impl_wrapped_config< detail::lookback_scan_determinism::nondeterministic, - config>(temporary_storage, - storage_size, - keys_input, - values_input, - size, - unique_output, - aggregates_output, - unique_count_output, - reduce_op, - key_compare_op, - stream, - debug_synchronous); + typename select_reduce_by_key_config::type, + Selector>(temporary_storage, + storage_size, + keys_input, + values_input, + size, + unique_output, + aggregates_output, + unique_count_output, + reduce_op, + key_compare_op, + stream, + debug_synchronous); } template; // accumulator_type + // RLE config needs to be converted to non_trivial_runs_config. + using non_trivial_config = typename convert_to_non_trivial_config::type; + using Selector = run_length_encode_non_trivial_config_selector; bool use_atomic_block_id; ROCPRIM_RETURN_ON_ERROR(check_if_using_atomic_block_id(stream, use_atomic_block_id)); @@ -130,16 +133,19 @@ hipError_t run_length_encode_non_trivial_runs_impl(void* tempo ROCPRIM_RETURN_ON_ERROR(std::visit( [&](auto use_sleepy_scan, auto use_atomic_block_id) { - using config = rocprim::detail::wrapped_non_trivial_runs_config; + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(non_trivial_config{}, current_target); using scan_state_type = ::rocprim::detail::lookback_scan_state; - detail::target_arch target_arch; - ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - - const non_trivial_runs_config_params params - = dispatch_target_arch(target_arch); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_block = block_size * params.kernel_config.items_per_thread; const std::size_t grid_size = detail::ceiling_div(size, items_per_block); @@ -221,12 +227,13 @@ hipError_t run_length_encode_non_trivial_runs_impl(void* tempo ordered_bid); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - non_trivial_kernel, - dim3(grid_size), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(current_target, + non_trivial_kernel, + dim3(grid_size), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("run_length_encode::non_trivial_kernel", size, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode_config.hpp index 708e89be3ee..5da14c7f4c5 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode_config.hpp @@ -55,180 +55,94 @@ struct run_length_encode_config namespace detail { -template -struct wrapped_trivial_runs_config - : wrapped_reduce_by_key_config -{}; - -template -struct wrapped_trivial_runs_config< - rocprim::run_length_encode_config, - KeyType, - AccumulatorType, - BinaryFunction> - : wrapped_reduce_by_key_config -{}; - -template -struct wrapped_trivial_runs_impl - : wrapped_reduce_by_key_impl -{}; - -template -struct wrapped_trivial_runs_impl< - KeyType, - AccumulatorType, - BinaryFunction, - std::enable_if_t::value && is_arithmetic::value - && is_binary_functional::value>> +template +struct select_reduce_by_key_config { - template - struct architecture_config - { - static constexpr reduce_by_key_config_params params - = default_trivial_runs_config(Arch), - KeyType, - AccumulatorType>{}; - }; + using type = typename Config::reduce_by_key; }; -template -struct wrapped_trivial_runs_config - : wrapped_trivial_runs_impl -{}; +template<> +struct select_reduce_by_key_config +{ + using type = rocprim::default_config; +}; -// Wrap around run_length_encode_config and the newly added non_trivial_runs_config for the -// run_length_encode_non_trivial_runs algorithm. Three cases are considered for selecting -// the appropriate config: -// -// - When a run_length_encode_config struct is passed as argument, an specialization of -// this struct takes care of mapping the parameters of that config to the newly added -// non_trivial_runs_config. -// -// - When a default config is passed, another specialization takes care of using the -// default set up of non_trivial_runs_config. -// -// - When a non_trivial_runs_config is passed, the params are set from this config. -// -template -struct wrapped_non_trivial_runs_config +template +struct run_length_encode_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template non_trivial_runs_config"); + using targets = run_length_encode_targets; + using param_type = reduce_by_key_config_params; + + param_type params; - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr non_trivial_runs_config_params params = RLENonTrivialRunsConfig{}; - }; + // Specialization for default config if types are not custom: instantiate the tuned config. + if constexpr(rocprim::is_arithmetic::value && rocprim::is_arithmetic::value + && rocprim::detail::is_binary_functional::value) + { + return run_length_encode_config_picker(); + } + else + { + return reduce_by_key_config_params_base(); + } + } + + template + constexpr run_length_encode_config_selector(Target) : params(picker_helper()) + {} }; -template -struct wrapped_non_trivial_runs_config< - rocprim::run_length_encode_config, - InputType> +template +struct convert_to_non_trivial_config { - template - struct architecture_config - { - // Mapping to non_trivial_runs_config. - // Beware that this mapping may impact performance of executions of - // run_length_encode_non_trivial_runs with the former run_length_encode_config, - // as it may not be the best for all cases. - static constexpr unsigned int block_size = ReduceByKeyConfig::block_size; - static constexpr unsigned int items_per_thread = ReduceByKeyConfig::items_per_thread; - - static constexpr block_load_method load_input_method = ReduceByKeyConfig::load_keys_method; - static constexpr block_scan_algorithm scan_algorithm = ReduceByKeyConfig::scan_algorithm; - - static constexpr non_trivial_runs_config_params params - = non_trivial_runs_config{}; - }; + using ReduceByKeyConfig = typename Config::reduce_by_key; + static constexpr unsigned int block_size = ReduceByKeyConfig::block_size; + static constexpr unsigned int items_per_thread = ReduceByKeyConfig::items_per_thread; + + static constexpr block_load_method load_input_method = ReduceByKeyConfig::load_keys_method; + static constexpr block_scan_algorithm scan_algorithm = ReduceByKeyConfig::scan_algorithm; + + using type + = non_trivial_runs_config; }; -// Generic for default config: instantiate base config. -template -struct wrapped_non_trivial_runs_impl +template<> +struct convert_to_non_trivial_config { - template - struct architecture_config - { - static constexpr non_trivial_runs_config_params params = - typename default_non_trivial_runs_config_base::type{}; - }; + using type = rocprim::default_config; }; -// Specialization for default config if types are arithmetic or half/bfloat16-precision -// floating point types: instantiate the tuned config. -template -struct wrapped_non_trivial_runs_impl::value>> +template +struct run_length_encode_non_trivial_config_selector { - template - struct architecture_config + using targets = run_length_encode_non_trivial_targets; + using param_type = non_trivial_runs_config_params; + + param_type params; + + template + constexpr param_type picker_helper() { - static constexpr non_trivial_runs_config_params params - = default_non_trivial_runs_config(Arch), InputType>{}; - }; + // Specialization for default config if types are not custom: instantiate the tuned config. + if constexpr(rocprim::is_arithmetic::value) + { + return run_length_encode_non_trivial_config_picker(); + } + else + { + return non_trivial_runs_config_params_base(); + } + } + + template + constexpr run_length_encode_non_trivial_config_selector(Target) + : params(picker_helper()) + {} }; -// Specialization for default config. -template -struct wrapped_non_trivial_runs_config - : wrapped_non_trivial_runs_impl -{}; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD - -template -template -constexpr reduce_by_key_config_params wrapped_trivial_runs_impl< - KeyType, - AccumulatorType, - BinaryFunction, - std::enable_if_t::value && is_arithmetic::value - && is_binary_functional::value>>::architecture_config:: - params; - -template -template -constexpr non_trivial_runs_config_params - wrapped_non_trivial_runs_config::architecture_config::params; - -template -template -constexpr non_trivial_runs_config_params wrapped_non_trivial_runs_config< - rocprim::run_length_encode_config, - InputType>::architecture_config::params; - -template -template -constexpr non_trivial_runs_config_params - wrapped_non_trivial_runs_impl::architecture_config::params; - -template -template -constexpr non_trivial_runs_config_params wrapped_non_trivial_runs_impl< - InputType, - std::enable_if_t::value>>::architecture_config::params; - -#endif // DOXYGEN_DOCUMENTATION_BUILD - } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp index 9c05485e49b..06c7690427a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp @@ -86,12 +86,18 @@ inline auto scan_impl(void* temporary_storage, scan_state_type scan_state; block_id_type block_id; - using config = wrapped_scan_config; + using Selector = scan_config_selector; detail::target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const auto params = dispatch_target_arch(target_arch); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); + const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -213,12 +219,13 @@ inline auto scan_impl(void* temporary_storage, (number_of_launch > 1), block_id); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - lookback_scan_kernel, - dim3(grid_size), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(current_target, + lookback_scan_kernel, + dim3(grid_size), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("lookback_scan_kernel", current_size, start); @@ -285,12 +292,12 @@ inline auto scan_impl(void* temporary_storage, // Save values into output array block_store_type().store(output, values, size, storage.store); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - single_scan_kernel, - dim3(1), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + single_scan_kernel, + dim3(1), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("single_scan_kernel", size, start); } return hipSuccess; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp index 7b2862060ee..fc755f4ee15 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp @@ -171,12 +171,15 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, ROCPRIM_RETURN_ON_ERROR(std::visit( [&](auto use_sleepy_scan, auto use_atomic_block_id) { - using config = wrapped_scan_by_key_config; + using Selector = scan_by_key_config_selector; detail::target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const scan_by_key_config_params params - = dispatch_target_arch(target_arch); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); using wrapped_type = ::rocprim::tuple; using scan_state_type = detail::lookback_scan_state; @@ -336,12 +339,13 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, last_keys_of_each_block, ordered_bid); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - device_scan_by_key_kernel, - dim3(scan_blocks), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(current_target, + device_scan_by_key_kernel, + dim3(scan_blocks), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("device_scan_by_key_kernel", current_size, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key_config.hpp index 7018b874d8f..5f1800ef8ac 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key_config.hpp @@ -32,42 +32,20 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_scan_by_key_config +template +struct scan_by_key_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template scan_by_key_config"); + using targets = scan_by_key_targets; + using param_type = scan_by_key_config_params; - template - struct architecture_config - { - static constexpr scan_by_key_config_params params = ScanByKeyConfig{}; - }; -}; + param_type params; -template -struct wrapped_scan_by_key_config -{ - template - struct architecture_config - { - static constexpr scan_by_key_config_params params - = default_scan_by_key_config(Arch), Key, Value>{}; - }; + template + constexpr scan_by_key_config_selector(Target) + : params(scan_by_key_config_picker()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr scan_by_key_config_params - wrapped_scan_by_key_config::architecture_config::params; - -template -template -constexpr scan_by_key_config_params - wrapped_scan_by_key_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan_config.hpp index f2a4254da20..10289469d7d 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan_config.hpp @@ -32,40 +32,18 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_scan_config +template +struct scan_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template scan_config"); - template - struct architecture_config - { - static constexpr scan_config_params params = ScanConfig{}; - }; -}; + using targets = scan_targets; + using param_type = scan_config_params; -template -struct wrapped_scan_config -{ - template - struct architecture_config - { - static constexpr scan_config_params params - = default_scan_config(Arch), Value>{}; - }; -}; + param_type params; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr scan_config_params - wrapped_scan_config::architecture_config::params; - -template -template -constexpr scan_config_params - wrapped_scan_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr scan_config_selector(Target) : params(scan_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_search_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_search_config.hpp index c6b23ea42ef..f50704d30eb 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_search_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_search_config.hpp @@ -33,40 +33,24 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_search_config +template +struct search_config_selector { - template - struct architecture_config - { - static constexpr search_config_params params = Config{}; - }; + // Targets can not be fully empty. + using targets + = comp_targets>; + using param_type = search_config_params; + + param_type params; + + template + constexpr search_config_selector(Target) + : params(param_type{ + 2048, kernel_config_params{256, 4} + }) + {} }; -// specialized for rocprim::default_config, which instantiates the default_search_config -template -struct wrapped_search_config -{ - template - struct architecture_config - { - static constexpr search_config_params params = {2048, kernel_config<256, 4>()}; - }; -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr search_config_params - wrapped_search_config::architecture_config::params; - -template -template -constexpr search_config_params - wrapped_search_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_search_n_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_search_n_config.hpp index 1d645eed3f6..1a4145ea184 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_search_n_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_search_n_config.hpp @@ -30,40 +30,18 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_search_n_config +template +struct search_n_config_selector { - template - struct architecture_config - { - static constexpr search_n_config_params params = Config{}; - }; -}; + using targets = search_n_targets; + using param_type = search_n_config_params; -// specialized for rocprim::default_config, which instantiates the default_search_n_config -template -struct wrapped_search_n_config -{ - template - struct architecture_config - { - static constexpr search_n_config_params params - = default_search_n_config(Arch), Value>{}; - }; -}; + param_type params; -#ifndef DOXYGEN_DOCUMENTATION_BUILD -template -template -constexpr search_n_config_params - wrapped_search_n_config::architecture_config::params; - -template -template -constexpr search_n_config_params - wrapped_search_n_config::architecture_config::params; -#endif // DOXYGEN_DOCUMENTATION_BUILD + template + constexpr search_n_config_selector(Target) : params(search_n_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index d3d4290d4b1..320ac214bab 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -80,7 +80,9 @@ struct Partitioner using input_type = typename std::iterator_traits::value_type; if(three_way_partitioning) { - using config = typename default_partition_config_base::type; + constexpr auto params = partition_config_params_base(); + using config = select_config; return partition_three_way(temporary_storage, storage_size, input, @@ -96,7 +98,9 @@ struct Partitioner } else { - using config = typename default_partition_config_base::type; + constexpr auto params = partition_config_params_base(); + using config = select_config; return partition(temporary_storage, storage_size, input, @@ -150,17 +154,17 @@ inline hipError_t segmented_radix_sort_impl( typename std::iterator_traits::value_type>::value, "ValuesInputIterator and ValuesOutputIterator must have the same value_type"); - using config = wrapped_segmented_radix_sort_config; + using Selector = segmented_radix_sort_config_selector; detail::target_arch target_arch; - hipError_t result = detail::host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); - const detail::segmented_radix_sort_config_params params - = detail::dispatch_target_arch(target_arch); + const auto params = get_config(Config{}, current_target); static constexpr bool with_values = !std::is_same::value; const bool partitioning_allowed = params.warp_sort_config.partitioning_allowed; @@ -353,12 +357,12 @@ inline hipError_t segmented_radix_sort_impl( }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(target_arch, - segmented_sort_large_kernel, - dim3(large_segment_count), - dim3(params.kernel_config.block_size), - 0, - stream)); + execute_launch_plan(current_target, + segmented_sort_large_kernel, + dim3(large_segment_count), + dim3(params.kernel_config.block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:large_segments", large_segment_count, start); @@ -391,10 +395,10 @@ inline hipError_t segmented_radix_sort_impl( }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + current_target, segmented_sort_medium_kernel, dim3(medium_segment_grid_size), dim3(params.warp_sort_config.block_size_medium), @@ -432,10 +436,10 @@ inline hipError_t segmented_radix_sort_impl( }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + current_target, segmented_sort_small_kernel, dim3(small_segment_grid_size), dim3(params.warp_sort_config.block_size_small), @@ -469,12 +473,13 @@ inline hipError_t segmented_radix_sort_impl( end_bit); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - segmented_sort_kernel, - dim3(segments), - dim3(params.kernel_config.block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(current_target, + segmented_sort_kernel, + dim3(segments), + dim3(params.kernel_config.block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort", segments, start); } return hipSuccess; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp index f4a1c8de9f3..26aac0a06de 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp @@ -57,46 +57,33 @@ using select_warp_sort_config_t namespace detail { -template -struct wrapped_segmented_radix_sort_config +template +struct segmented_radix_sort_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template segmented_radix_sort_config"); - - template - struct architecture_config - { - static constexpr detail::segmented_radix_sort_config_params params - = SegmentedRadixSortConfig{}; - }; + using targets = segmented_radix_sort_targets; + using param_type = segmented_radix_sort_config_params; + + param_type params; + + template + constexpr segmented_radix_sort_config_selector(Target) + : params(segmented_radix_sort_config_picker()) + {} }; -template -struct wrapped_segmented_radix_sort_config +template +struct segmented_radix_sort_warp_sort_small_config_static_selector { - template - struct architecture_config - { - static constexpr segmented_radix_sort_config_params params - = detail::default_segmented_radix_sort_config(Arch), - key_type, - value_type>{}; - }; + static constexpr auto block_size + = target_config::params.warp_sort_config.block_size_small; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr segmented_radix_sort_config_params - wrapped_segmented_radix_sort_config:: - architecture_config::params; -template -template -constexpr segmented_radix_sort_config_params - wrapped_segmented_radix_sort_config:: - architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS +template +struct segmented_radix_sort_warp_sort_medium_config_static_selector +{ + static constexpr auto block_size + = target_config::params.warp_sort_config.block_size_medium; +}; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp index 1b6b011023a..c32b6b5186b 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp @@ -63,15 +63,17 @@ inline hipError_t segmented_reduce_impl(void* temporary_storage, using input_type = typename std::iterator_traits::value_type; using result_type = ::rocprim::accumulator_t; - using config = wrapped_segmented_reduce_config; + using Selector = segmented_reduce_config_selector; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const reduce_config_params params = dispatch_target_arch(target_arch); + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; @@ -102,12 +104,12 @@ inline hipError_t segmented_reduce_impl(void* temporary_storage, static_cast(initial_value)); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - segmented_reduce_kernel, - dim3(segments), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + segmented_reduce_kernel, + dim3(segments), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_reduce", segments, start); return hipSuccess; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce_config.hpp index f64800af03c..0b929a18116 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce_config.hpp @@ -33,42 +33,20 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_segmented_reduce_config +template +struct segmented_reduce_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template reduce_config"); + using targets = segmented_reduce_targets; + using param_type = reduce_config_params; - template - struct architecture_config - { - static constexpr reduce_config_params params = ReduceConfig(); - }; -}; + param_type params; -template -struct wrapped_segmented_reduce_config -{ - template - struct architecture_config - { - static constexpr reduce_config_params params - = default_segmented_reduce_config(Arch), Value>(); - }; + template + constexpr segmented_reduce_config_selector(Target) + : params(segmented_reduce_config_picker()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr reduce_config_params - wrapped_segmented_reduce_config::architecture_config::params; - -template -template -constexpr reduce_config_params - wrapped_segmented_reduce_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp index cac12f2ef82..d40ae1c3837 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp @@ -94,15 +94,17 @@ inline hipError_t segmented_scan_impl(void* temporary_storage, using input_type = typename std::iterator_traits::value_type; using result_type = typename std::conditional::type; - using config = wrapped_scan_config; + using Selector = detail::scan_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const scan_config_params params = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; @@ -133,12 +135,12 @@ inline hipError_t segmented_scan_impl(void* temporary_storage, scan_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - segmented_scan_kernel, - dim3(segments), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + segmented_scan_kernel, + dim3(segments), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_scan", segments, start); return hipSuccess; } diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp index 80e36fca6bf..8880dcf242f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp @@ -48,6 +48,7 @@ namespace detail template @@ -58,24 +59,23 @@ inline hipError_t transform_impl(InputIterator input, const hipStream_t stream, bool debug_synchronous) { + using input_type = typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::invoke_result::type; + if(size == size_t(0)) { return hipSuccess; } - using input_type = typename std::iterator_traits::value_type; - using result_type = typename ::rocprim::invoke_result::type; + detail::target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - using config = detail::wrapped_transform_config; + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - detail::target_arch target_arch; - hipError_t result = detail::host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const detail::transform_config_params params - = detail::dispatch_target_arch(target_arch); + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -123,12 +123,12 @@ inline hipError_t transform_impl(InputIterator input, transform_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - transform_kernel, - current_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + transform_kernel, + current_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("transform_kernel", current_size, start); } @@ -204,13 +204,17 @@ inline hipError_t transform(InputIterator input, constexpr bool is_pointer = std::is_pointer::value && std::is_pointer::value; - return detail::transform_impl( - input, - output, - size, - transform_op, - stream, - debug_synchronous); + using input_type = typename std::iterator_traits::value_type; + using selector = detail::transform_config_selector; + + return detail:: + transform_impl( + input, + output, + size, + transform_op, + stream, + debug_synchronous); } /// \brief Parallel device-level transform primitive for two inputs. diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_transform_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_transform_config.hpp index a74bc163fee..4327b545023 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_transform_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_transform_config.hpp @@ -39,58 +39,33 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_transform_config +template +struct transform_config_selector { - static_assert(std::is_base_of::value, - "Config must be a specialization of struct template transform_config"); + using targets = std::conditional_t; + using param_type = transform_config_params; - template - struct architecture_config - { - static constexpr transform_config_params params = TransformConfig{}; - }; -}; - -template -struct wrapped_transform_config -{ - template - struct architecture_config - { - static constexpr transform_config_params params - = default_transform_pointer_config(Arch), Value>{}; - }; -}; + param_type params; -template -struct wrapped_transform_config -{ - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr transform_config_params params - = default_transform_config(Arch), Value>{}; - }; + // Different configs if it is a pointer. + if constexpr(IsPointer) + { + return transform_pointer_config_picker(); + } + else + { + return transform_config_picker(); + } + } + + template + constexpr transform_config_selector(Target) : params(picker_helper()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr transform_config_params - wrapped_transform_config::architecture_config::params; - -template -template -constexpr transform_config_params - wrapped_transform_config::architecture_config::params; - -template -template -constexpr transform_config_params - wrapped_transform_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp index ddcd02836ac..5fa445e8b49 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp @@ -52,15 +52,17 @@ inline hipError_t radix_sort_block_sort(KeysInputIterator keys_input, using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_radix_sort_block_sort_config; + using Selector = radix_sort_block_sort_config_selector; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const kernel_config_params params = dispatch_target_arch(target_arch); + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); sort_items_per_block = params.block_size * params.items_per_thread; const unsigned int sort_number_of_blocks = ceiling_div(size, sort_items_per_block); @@ -99,14 +101,14 @@ inline hipError_t radix_sort_block_sort(KeysInputIterator keys_input, }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(target_arch, - radix_sort_block_sort_kernel, - dim3(sort_number_of_blocks), - dim3(params.block_size), - 0, - stream)); + execute_launch_plan( + current_target, + radix_sort_block_sort_kernel, + dim3(sort_number_of_blocks), + dim3(params.block_size), + 0, + stream)); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("radix_sort_block_sort_kernel", size, start); return hipSuccess; } diff --git a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp index 35b416a9485..681b5bd3746 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp @@ -203,11 +203,13 @@ hipError_t radix_sort_merge_impl( // In the case that the user provides no custom config for merge sort block sort, // instead of using the autotuned merge_sort_block_sort_config, use a hard-coded config that // a power-of-two items sorted per block. - using default_block_sort_config = - typename radix_sort_block_sort_config_base::type; + constexpr auto default_block_sort_config + = radix_sort_block_sort_config_params_base(); + using radix_sort_block_sort_config = typename std::conditional, // extract the relevant config from merge_sort_block_sort_config typename Config::block_sort_config::sort_config>::type; static_assert( @@ -218,20 +220,21 @@ hipError_t radix_sort_merge_impl( using merge_sort_block_merge_config = typename std:: conditional::type; + using selector_bm = merge_sort_block_merge_config_selector; + using selector_bs = merge_sort_block_sort_config_selector; + // Wrap our radix_sort_block_sort kernel config in a merge_sort_block_sort_config // just so device_merge_sort_compile_time_verifier can check. using merge_sort_block_sort_config = merge_sort_block_sort_config; - using wrapped_bs_config - = wrapped_merge_sort_block_sort_config; - using wrapped_bm_config = wrapped_merge_sort_block_merge_config; // Some helpful checks during compile-time - device_merge_sort_compile_time_verifier(); + device_merge_sort_compile_time_verifier(); // We will get this later from the block_sort algorithm unsigned int sort_items_per_block diff --git a/projects/rocprim/rocprim/include/rocprim/thread/thread_load.hpp b/projects/rocprim/rocprim/include/rocprim/thread/thread_load.hpp index f762d6b1eef..b410b91ec75 100644 --- a/projects/rocprim/rocprim/include/rocprim/thread/thread_load.hpp +++ b/projects/rocprim/rocprim/include/rocprim/thread/thread_load.hpp @@ -85,14 +85,14 @@ T asm_thread_load(void* ptr) ROCPRIM_DEVICE ROCPRIM_INLINE type asm_thread_load(void* ptr) \ { \ interim_type retval; \ - if ROCPRIM_AMDGCN_CONSTEXPR(IS_RDNA4()) \ + if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_RDNA4()) \ { \ asm volatile(#asm_operator " %0, %1 th:TH_DEFAULT scope:SCOPE_DEV\n\t" \ "s_wait_loadcnt_dscnt(%2)" \ : "=&v"(retval) \ : "v"(ptr), "I"(0x00)); \ } \ - else if ROCPRIM_AMDGCN_CONSTEXPR(IS_CDNA3()) \ + else if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_CDNA3()) \ { \ asm volatile(#asm_operator " %0, %1 sc0 nt\n\t" \ "s_waitcnt(%2)" \ diff --git a/projects/rocprim/rocprim/include/rocprim/thread/thread_store.hpp b/projects/rocprim/rocprim/include/rocprim/thread/thread_store.hpp index bdc1583f886..a12d6e50ed0 100644 --- a/projects/rocprim/rocprim/include/rocprim/thread/thread_store.hpp +++ b/projects/rocprim/rocprim/include/rocprim/thread/thread_store.hpp @@ -83,14 +83,14 @@ void asm_thread_store(void* ptr, T val) type val) \ { \ interim_type temp_val = *bit_cast(&val); \ - if ROCPRIM_AMDGCN_CONSTEXPR(IS_RDNA4()) \ + if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_RDNA4()) \ { \ asm volatile(#asm_operator " %0, %1 th:TH_DEFAULT scope:SCOPE_DEV\n\t" \ "s_wait_storecnt_dscnt(%2)" \ : \ : "v"(ptr), "v"(temp_val), "I"(0x00)); \ } \ - else if ROCPRIM_AMDGCN_CONSTEXPR(IS_CDNA3()) \ + else if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_CDNA3()) \ { \ asm volatile(#asm_operator " %0, %1 sc0 nt\n\t" \ "s_waitcnt(%2)" \ diff --git a/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp b/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp index e478d712c4f..62bfb36aa0a 100644 --- a/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp +++ b/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp @@ -94,16 +94,21 @@ class warp_reduce_dpp } #if !ROCPRIM_TARGET_SPIRV - static_assert(VirtualWaveSize <= 32, - "VirtualWaveSize > 32 is not supported without DPP broadcasts"); -#else - if constexpr(VirtualWaveSize > 32) + if constexpr(!ROCPRIM_IS_GENERIC()) { - ROCPRIM_PRINT_ERROR_ONCE( - "VirtualWaveSize > 32 is not supported without DPP broadcasts"); - return; + static_assert(VirtualWaveSize <= 32, + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); } + else #endif + { + if constexpr(VirtualWaveSize > 32) + { + ROCPRIM_PRINT_ERROR_ONCE( + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); + return; + } + } } else { diff --git a/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp b/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp index 58ccedc5f1a..3d2edf287b8 100644 --- a/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp +++ b/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp @@ -88,16 +88,21 @@ class warp_scan_dpp } #if !ROCPRIM_TARGET_SPIRV - static_assert(VirtualWaveSize <= 32, - "VirtualWaveSize > 32 is not supported without DPP broadcasts"); -#else - if constexpr(VirtualWaveSize > 32) + if constexpr(!ROCPRIM_IS_GENERIC()) { - ROCPRIM_PRINT_ERROR_ONCE( - "VirtualWaveSize > 32 is not supported without DPP broadcasts"); - return; + static_assert(VirtualWaveSize <= 32, + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); } + else #endif + { + if constexpr(VirtualWaveSize > 32) + { + ROCPRIM_PRINT_ERROR_ONCE( + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); + return; + } + } } else { diff --git a/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py b/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py index 6636d862223..9c5884e300a 100644 --- a/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py +++ b/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py @@ -21,6 +21,7 @@ # THE SOFTWARE. import argparse +from dataclasses import dataclass import json import re import sys @@ -38,6 +39,129 @@ slower_rejected_count = 0 marginal_improvements_rejected_count = 0 +TARGET_GPUS_DICT = { + "MI350X": "mi350x", + "MI325X": "mi325x", + "MI308X": "mi308x", + "MI300A": "mi300a", + "MI300X": "mi300x", + "MI210": "mi210", + "MI100": "mi100", + "RX 9060": "rx9060", + "RX 9070": "rx9070", + "V620": "v620", + "RX 7900": "rx7900", + "RX 6900": "rx6900" +} + +def get_gen_from_architecture(arch): + match arch: + case "target_arch::gfx803": + return "gen::gcn3" + case "target_arch::gfx900" | "target_arch::gfx906": + return "gen::gcn5" + case "target_arch::gfx908": + return "gen::cdna1" + case "target_arch::gfx90a": + return "gen::cdna2" + case "target_arch::gfx942": + return "gen::cdna3" + case "target_arch::gfx950": + return "gen::cdna4" + case ( + "target_arch::gfx1010" + | "target_arch::gfx1011" + | "target_arch::gfx1012" + ): + return "gen::rdna1" + case "target_arch::gfx1030": + return "gen::rdna2" + case ( + "target_arch::gfx1100" + | "target_arch::gfx1101" + | "target_arch::gfx1102" + | "target_arch::gfx1103" + | "target_arch::gfx1150" + | "target_arch::gfx1151" + | "target_arch::gfx1152" + | "target_arch::gfx1153" + ): + return "gen::rdna3" + case "target_arch::gfx1200" | "target_arch::gfx1201": + return "gen::rdna4" + case "target_arch::unknown" | "target_arch::invalid": + return "gen::unknown" + case _: + return "gen::unknown" + + +def get_target_gpu_from_context(context): + """ + Uses the benchmark run context embedded into the benchmark output json to retrieve the targeted gpu + """ + + gpu_from_context = context['hdp_name'] + + ret = "gpu::generic" + + for gpu in TARGET_GPUS_DICT.keys(): + if gpu in gpu_from_context: + ret = f'gpu::{TARGET_GPUS_DICT[gpu]}' + + if ret == "gpu::generic": + print(f"WARNING: Unrecognized GPU '{gpu_from_context}', so using gpu::generic", file=sys.stderr, flush=True) + return ret + +@dataclass(frozen=True, eq=True) +class Target: + """ + Data class describing the target of the benchmark + """ + gen: str = "gen::unknown" + arch: str = "target_arch::unknown" + gpu: str = "gpu::generic" + rep: str = "rep::amdgcn" + + def __iter__(self): + yield from (self.gen, self.arch, self.gpu, self.rep) + + def as_str(self) -> str: + return ", ".join(str(v) for v in self) + + @classmethod + def from_string(self, raw: str) -> "Target": + """Create Target instance from a raw string with key::value pairs""" + # Split by commas and strip whitespace/newlines + parts = [p.strip() for p in raw.split(",") if p.strip()] + + # Default data + data = { + "gen": self.gen, + "arch": self.arch, + "gpu": self.gpu, + "rep": self.rep, + } + + # Assign fields based on prefixes + for part in parts: + if part.startswith("gen::"): + data["gen"] = part + elif part.startswith("target_arch::"): + data["arch"] = part + elif part.startswith("gpu::"): + data["gpu"] = part + elif part.startswith("rep::"): + data["rep"] = part + + return self(**data) + + +def get_target(context): + arch = f"target_arch::{context["hdp_gcn_arch_name"].split(":")[0]}" + gpu = get_target_gpu_from_context(context) + gen = get_gen_from_architecture(arch) + rep = "rep::amdgcn" + return Target(gen, arch, gpu, rep) class colors: OK = "\033[92m" @@ -90,7 +214,8 @@ def add_new_contenders( new_config: Dict[str, Any], new_alg_data: Dict[str, Any], score_assigner: Callable[[List[Dict[str, Any]]], None], - contenders: Dict[Tuple[Any, str], Contender], + contenders: Dict[Target, Dict[Any, Contender]], + picker_strings: Dict[Tuple[Target, str], str], improvement_threshold_percentage: float, ) -> bool: global warnings, total_specializations, improvement_count, new_specialization_count, noisy_rejected_count, slower_rejected_count, marginal_improvements_rejected_count @@ -100,23 +225,25 @@ def add_new_contenders( Z_SCORE = 1.96 # 95% confidence # Collect all rows first to determine max widths - for arch, new_arch_specializations in new_config["specializations"].items(): - if arch not in new_alg_data and ( - arch != "unknown" or "gfx908" not in new_alg_data - ): + for target, new_target_specializations in new_config["specializations"].items(): + if Target() == target: + continue + if target not in new_alg_data: warnings.append( - f"{colors.WARN}The new JSON data is missing {arch} for {algorithm_name}{colors.END_COLOR}" + f"{colors.WARN}The new JSON data is missing {target} for {algorithm_name}{colors.END_COLOR}" ) continue - # create_optimization.py its create_config_file_content() chose to make "unknown" a copy of "gfx908" - new_arch_data = new_alg_data["gfx908" if arch == "unknown" else arch] - for instance_key, new_instance_data in new_arch_specializations.items(): - if instance_key not in new_arch_data: + new_target_data = new_alg_data[target] + for instance_key, new_instance_data in list(new_target_specializations.items()): + if instance_key in {"begin_of_picker", "end_of_picker"}: + picker_strings[(target, instance_key)] = new_instance_data + continue + if instance_key not in new_target_data: sys.exit( - f"{colors.FAIL}The new JSON data is missing {arch} specialization '{stringify_instance_key(instance_key)}' for {algorithm_name}{colors.END_COLOR}" + f"{colors.FAIL}The new JSON data is missing {target} specialization '{stringify_instance_key(instance_key)}' for {algorithm_name}{colors.END_COLOR}" ) - new_instances = new_arch_data[instance_key] + new_instances = new_target_data[instance_key] add_base_args(new_instance_data["base_args"], new_instances) score_assigner(new_instances) @@ -125,15 +252,13 @@ def add_new_contenders( total_specializations += 1 row = {} - row["arch"] = str(arch) + row["target"] = target.as_str() row["key"] = stringify_instance_key(instance_key) row["new_family_index"] = new_best_instance["family_index"] row["new_bps"] = f"{new_best_instance['bytes_per_second']:.2e}" - key = (instance_key, arch) - # If there is no old config specialization, we always accept the new one. - if key not in contenders: + if target not in contenders or instance_key not in contenders[target]: status = "New" colored_status = f"{colors.OK}{status}{colors.END_COLOR}" row["status"] = colored_status @@ -147,12 +272,14 @@ def add_new_contenders( rows.append(row) new_specialization_count += 1 improved = True - contenders[key] = Contender( + if target not in contenders: + contenders[target] = {} + contenders[target][instance_key] = Contender( instance=new_best_instance, string=new_instance_data["string"] ) continue - old_best_instance = contenders[key].instance + old_best_instance = contenders[target][instance_key].instance row["old_family_index"] = old_best_instance.get("family_index", "-") row["old_bps"] = ( @@ -217,7 +344,7 @@ def add_new_contenders( rows.append(row) improvement_count += 1 improved = True - contenders[key] = Contender( + contenders[target][instance_key] = Contender( instance=new_best_instance, string=new_instance_data["string"] ) @@ -225,7 +352,7 @@ def add_new_contenders( ("status", f"Status of {algorithm_name}"), ("noise", "Noise (old/new)"), ("bps", "Bytes/sec (old/new)"), - ("arch", "Arch"), + ("target", "Target"), ("key", "Specialization"), ("family_index", "Family index (old/new)"), ] @@ -281,35 +408,40 @@ def get_old_contenders( old_config: Dict[str, Any], old_alg_data: Dict[str, Any], score_assigner: Callable[[List[Dict[str, Any]]], None], -) -> Dict[Tuple[Any, str], Contender]: +): global warnings - contenders: Dict[Tuple[Any, str], Contender] = {} + contenders: Dict[Target, Dict[Any, Contender]] = {} + picker_strings: Dict[Tuple[str, str], str] = {} old_config.setdefault("specializations", {}) - for arch, old_arch_specializations in old_config["specializations"].items(): + for target, old_target_specializations in old_config["specializations"].items(): + if Target() == target: + continue + contenders[target] = {} # Always keep every old specialization, even if missing in old_alg_data - if arch not in old_alg_data and ( - arch != "unknown" or "gfx908" not in old_alg_data - ): + if target not in old_alg_data: warnings.append( - f"{colors.WARN}The old JSON data is missing {arch} for {algorithm_name}{colors.END_COLOR}" + f"{colors.WARN}The old JSON data is missing {target} for {algorithm_name}{colors.END_COLOR}" ) - old_arch_data = {} + old_target_data = {} else: - old_arch_data = old_alg_data["gfx908" if arch == "unknown" else arch] + old_target_data = old_alg_data[target] - for instance_key, old_instance_data in old_arch_specializations.items(): - if instance_key not in old_arch_data: - # If old_arch_data is falsy, then a warning was already printed - if old_arch_data: + for instance_key, old_instance_data in old_target_specializations.items(): + if instance_key in {"begin_of_picker", "end_of_picker"}: + picker_strings[(target, instance_key)] = old_instance_data + continue + if instance_key not in old_target_data: + # If old_target_data is falsy, then a warning was already printed + if old_target_data: warnings.append( - f"{colors.WARN}The old JSON data is missing {arch} specialization '{stringify_instance_key(instance_key)}' for {algorithm_name}{colors.END_COLOR}" + f"{colors.WARN}The old JSON data is missing {target} specialization '{stringify_instance_key(instance_key)}' for {algorithm_name}{colors.END_COLOR}" ) old_instances = [] else: - old_instances = old_arch_data[instance_key] + old_instances = old_target_data[instance_key] add_base_args(old_instance_data["base_args"], old_instances) score_assigner(old_instances) @@ -319,17 +451,38 @@ def get_old_contenders( get_best_instance(old_instances) if old_instances else {} ) - key = (instance_key, arch) - contenders[key] = Contender( + contenders[target][instance_key] = Contender( instance=old_best_instance, string=old_instance_data["string"] ) - return contenders + return (contenders, picker_strings) + + +def get_comp_targets(old_config: Dict[str, Any], new_config: Dict[str, Any], algorithm_name: str): + ret = f"// All existing configs\nusing {algorithm_name}_targets = comp_targets<" + unique_targets = [] + for target in new_config["specializations"].keys(): + if target != Target() and target not in unique_targets: + unique_targets.append(target) + for target in old_config["specializations"].keys(): + if target != Target() and target not in unique_targets: + unique_targets.append(target) + + for unique_target in unique_targets: + ret += f"comp_target<{unique_target.as_str()}>," + + ret += f"comp_target<{Target().as_str()}>>;" + return ret def get_best_instance(instances): return max(instances, key=lambda x: x["score"]) +def parse_args(base_args): + numbers = [] + for item in base_args: + numbers.extend(map(int, re.findall(r'\d+', item))) + return numbers def add_base_args(base_args, instances): """ @@ -337,8 +490,9 @@ def add_base_args(base_args, instances): """ for instance in instances: if instance["algo"] in {"merge_sort_block_sort", "radix_sort_block_sort"}: - instance["bs"] = int(base_args[0]) - instance["ipt"] = int(base_args[1]) + args = parse_args(base_args) + instance["bs"] = int(args[0]) + instance["ipt"] = int(args[1]) def score_assigner_default(instances: List[Dict[str, Any]]) -> None: @@ -399,7 +553,7 @@ def generate_improved_configs( new_alg_data = new_data.get(algorithm_name, {}) score_assigner = get_score_assigner(algorithm_name) - contenders = get_old_contenders( + contenders, picker_strings = get_old_contenders( algorithm_name, old_config or {}, old_alg_data or {}, score_assigner ) improved = add_new_contenders( @@ -408,6 +562,7 @@ def generate_improved_configs( new_alg_data, score_assigner, contenders, + picker_strings, improvement_threshold_percentage, ) if not improved: @@ -420,14 +575,32 @@ def generate_improved_configs( if old_config else new_config["start_of_config"] ) - for contender in contenders.values(): - f.write(contender.string) + # Add every picker for every different target + for target, target_contenders in contenders.items(): + f.write(picker_strings[(target, "begin_of_picker")]) + for contender in target_contenders.values(): + f.write(contender.string) + f.write(picker_strings[(target, "end_of_picker")]) + + # Add unknown target fallback case + unknown_target_contender = ( + old_config["specializations"][Target()]["begin_of_picker"] + if old_config + else new_config["specializations"][Target()]["begin_of_picker"] + ).lstrip("\n") + + f.write(unknown_target_contender) + + # Add comp_targets + comp_targets = get_comp_targets(old_config, new_config, algorithm_name) + f.write(comp_targets) + # Remove leading newlines from end_of_config to avoid triple newlines end = ( old_config["end_of_config"] if old_config else new_config["end_of_config"] - ).lstrip("\n") + ) if not end.endswith("\n"): end += "\n" f.write(end) @@ -437,24 +610,30 @@ def extract_template_arguments(code: str) -> list[str]: """ Extracts the top-level template arguments from a base config class in a C++ struct specialization. - Assumes the input contains a line ending in `: XXX_config<...> {`, where `XXX_config` is - the name of the base config class and the angle brackets contain template arguments. + Assumes the input contains a line ending in `: XXX_config_params{...} {`, where `XXX_config_params` is + the name of the base config params class and the curly brackets contain template arguments. Args: code (str): A string containing the C++ struct definition. Returns: List[str]: A list of top-level template arguments as strings. - Nested templates like `kernel_config<1024, 1>` are preserved as single items. + Nested structs like `kernel_config_params{1024, 1}` are preserved as single items. Example: Input: - 'struct foo : some_config<256, kernel_config<128, 2>, 1 << 2> {' + ``` + if constexpr(true) + { + return some_config_params{256, kernel_config_params{128, 2}, 1 << 2}; + } + ``` Output: - ['256', 'kernel_config<128, 2>', '1 << 2'] + ['256', 'kernel_config_params{128, 2}', '1 << 2'] """ # Match the base class ending in `_config<...> {` - match = re.search(r":\s*\w+_config<(.+?)>\s*{", code, re.DOTALL) + match = re.search(r"return\s+\w+_config_params\s*\{\s*([\s\S]*?)\s*\};", code, re.DOTALL) + if not match: return [] @@ -466,15 +645,15 @@ def extract_template_arguments(code: str) -> list[str]: depth = 0 i = 0 while i < len(template_args): - if template_args[i : i + 2] in ("<<", ">>"): + if template_args[i : i + 2] in ("{{", "}}"): current += template_args[i : i + 2] i += 2 continue char = template_args[i] - if char == "<": + if char == "{": depth += 1 - elif char == ">": + elif char == "}": depth -= 1 if char == "," and depth == 0: @@ -501,21 +680,24 @@ def get_specialization_key(specialization_string: str) -> Any: # 'select_flag': { # 'full_name': 'device_select_flag.hpp', # 'start_of_config': '// Copyright ...', +# 'comp_targets': 'using select_flag_targets = comp_targets< ... >;' # 'end_of_config': '} // end namespace detail ...', # 'specializations': { -# 'gfx1200': { +# Target(gen::rdna2, target_arch::gfx1030, gpu::v620, rep::amdgcn): { +# 'begin_of_picker': 'template constexpr auto select_flag_config_picker() ... {\n', +# 'end_of_picker': `// Default case if none of the conditions match\nreturn partition_config_params_base();}\n`, # Instance(key_type='double'): { -# 'string': '// Based on key_type = double\n ... select_config<512, kernel_config<128, 2>>\n{};\n\n', +# 'string': '// Based on key_type = double\n ... partition_config_params{512, kernel_config_params{128, 2}};\n', # 'base_args': [ # '512', -# 'kernel_config<128, 2>' +# 'kernel_config_params{128, 2}' # ] # }, # Instance(key_type='float'): { -# 'string': '// Based on key_type = float\n ... select_config<1024, kernel_config<128, 2>>\n{};\n\n', +# 'string': '// Based on key_type = float\n ... partition_config_params{1024, kernel_config_params{128, 2}};\n', # 'base_args': [ # '1024', -# 'kernel_config<128, 2>' +# 'kernel_config_params{128, 2}' # ] # } # } @@ -531,32 +713,55 @@ def read_configs(dir_path: Path) -> Dict[str, Any]: config["full_name"] = hpp_path.name text = hpp_path.read_text() - matches = re.split(r"(// Based on .*?{\s*};\n\n)", text, flags=re.DOTALL) - matches = list(filter(None, matches)) - config["start_of_config"] = matches[0] - config["end_of_config"] = matches[-1] + pickers = re.split(r"(template\s*<[\s\S]*?>\s*\{[\s\S]*?}\n\n)", text, flags=re.DOTALL) + pickers = list(filter(None, pickers)) + # Remove spaces from picker functions to have a more consistent input. + pickers = [picker for picker in pickers if picker.strip() != ""] + + # All the code before the picker functions. + config["start_of_config"] = pickers[0] + + # The code after the picker functions need to be divided in comp_targets and the end. + comp_target = re.split(r"(// .*\nusing[\s\S]*?;)", pickers[-1], flags=re.DOTALL) + comp_target = list(filter(None, comp_target)) + config["comp_targets"] = comp_target[0] + config["end_of_config"] = comp_target[-1] all_specializations = config.setdefault("specializations", {}) - specialization_strings = matches[1:-1] - for specialization_string in specialization_strings: - arch_match = re.search(r"target_arch::(.*?)\)", specialization_string) - if not arch_match: - sys.exit( - f"{colors.FAIL}Could not find arch in specialization: {specialization_string}{colors.END_COLOR}" - ) - arch = arch_match.group(1) + picker_strings = pickers[1:-1] + for picker_string in picker_strings: + target_match = re.search(r"comp_target<((.||\s)*?)>", picker_string) - specialization_key = get_specialization_key(specialization_string) - arch_specializations = all_specializations.setdefault(arch, {}) - if specialization_key in arch_specializations: + if not target_match: sys.exit( - f"{colors.FAIL}Specialization key duplicate '{specialization_key}'{colors.END_COLOR}" + f"{colors.FAIL}Could not find arch in specialization: {picker_string}{colors.END_COLOR}" ) - arch_specializations[specialization_key] = { - "string": specialization_string, - "base_args": extract_template_arguments(specialization_string), - } + + target = Target.from_string(target_match.group(1)) + + specializations = re.split(r"(\s*// Based on[\s\S]*?\}\n(?=\s*//))", picker_string, flags=re.DOTALL) + specializations = list(filter(None, specializations)) + target_specializations = all_specializations.setdefault(target, {}) + target_specializations["begin_of_picker"] = specializations[0] + target_specializations["end_of_picker"] = specializations[-1] + + for specialization_string in specializations[1:-1]: + specialization_key = get_specialization_key(specialization_string) + + if specialization_key in target_specializations: + sys.exit( + f"{colors.FAIL}Specialization key duplicate '{specialization_key}'{colors.END_COLOR}" + ) + args = extract_template_arguments(specialization_string) + if args is []: + sys.exit( + f"{colors.FAIL}Arguments are empty '{specialization_key}'{colors.END_COLOR}" + ) + target_specializations[specialization_key] = { + "string": specialization_string, + "base_args": args, + } return configs @@ -575,7 +780,7 @@ def get_instance_key(instanced_types: Dict[str, Any], selectors: List[str]) -> A # This is the format of the returned dictionary: # { # 'select_flag': { -# 'gfx1200': { +# Target(gen::rdna2, target_arch::gfx1030, gpu::v620, rep::amdgcn): { # Instance(key_type='double'): [ # { 'items_per_second': 200, 'segment_count': 10 }, # { 'items_per_second': 300, 'segment_count': 20 } @@ -592,7 +797,7 @@ def read_data(dir_path: Path, selectors: Dict[str, List[str]]) -> Dict[str, Any] for json_path in dir_path.rglob("*.json"): json_data = json.loads(json_path.read_text()) - arch = json_data["context"]["hdp_gcn_arch_name"].split(":")[0] + target = get_target(json_data["context"]) for benchmark in json_data["benchmarks"]: name = benchmark["name"] @@ -614,14 +819,14 @@ def read_data(dir_path: Path, selectors: Dict[str, List[str]]) -> Dict[str, Any] algorithm_name += "_" + data["subalgo"] alg_data = all_data.setdefault(algorithm_name, {}) - arch_data = alg_data.setdefault(arch, defaultdict(list)) + target_data = alg_data.setdefault(target, defaultdict(list)) if algorithm_name not in selectors: sys.exit( f"{colors.FAIL}No selectors found for algorithm '{algorithm_name}' in the selectors JSON{colors.END_COLOR}" ) instance_key = get_instance_key(data, selectors[algorithm_name]) - arch_data[instance_key].append(data) + target_data[instance_key].append(data) return all_data @@ -706,21 +911,23 @@ def add_arguments(parser: StrictArgumentParser) -> None: class TestExtractTemplateArguments(unittest.TestCase): def test_complex(self): specialization = """ - struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, (1 << 17) + 1 >> 2 + 70000, - 8, - block_radix_rank_algorithm::match> - {}; + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + (1 << 17) + 1 >> 2 + 70000, + 8, + block_radix_rank_algorithm::match + }; + } + // Needs a comment after it. """ expected = [ - "kernel_config<1024, 1>", + "kernel_config_params{1024, 1}", "(1 << 17) + 1 >> 2 + 70000", "8", "block_radix_rank_algorithm::match", @@ -753,6 +960,11 @@ def main() -> None: args.improved_configs_dir, ) + if total_specializations == 0: + warnings.append( + f"{colors.WARN}No specializations are found in the config files!{colors.END_COLOR}" + ) + if warnings: print("\n".join(warnings)) print("") diff --git a/projects/rocprim/scripts/autotune/create_optimization.py b/projects/rocprim/scripts/autotune/create_optimization.py index 95b58e6dbb1..b7ff6b07ff4 100755 --- a/projects/rocprim/scripts/autotune/create_optimization.py +++ b/projects/rocprim/scripts/autotune/create_optimization.py @@ -41,7 +41,60 @@ from typing import Dict, List, Callable, Optional, Tuple from jinja2 import Environment, PackageLoader, select_autoescape -TARGET_ARCHITECTURES = ['gfx803', 'gfx900', 'gfx906', 'gfx908', 'gfx90a', 'gfx942', 'gfx1030', 'gfx1100', 'gfx1102', 'gfx1201'] +TARGET_ARCHITECTURES = [ + "gfx803", + "gfx900", + "gfx906", + "gfx908", + "gfx90a", + "gfx942", + "gfx950", + "gfx1010", + "gfx1011", + "gfx1012", + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1103", + "gfx1150", + "gfx1151", + "gfx1152", + "gfx1153", + "gfx1200", + "gfx1201", +] + +TARGET_GPUS_DICT = { + "MI350X": "mi350x", + "MI325X": "mi325x", + "MI308X": "mi308x", + "MI300A": "mi300a", + "MI300X": "mi300x", + "MI210": "mi210", + "MI100": "mi100", + "RX 9060": "rx9060", + "RX 9070": "rx9070", + "V620": "v620", + "RX 7900": "rx7900", + "RX 6900": "rx6900" +} +@dataclass(frozen=True, eq=True) +class Target: + """ + Data class describing the target of the benchmark + """ + gen: str = "gen::unknown" + arch: str = "target_arch::unknown" + gpu: str = "gpu::generic" + rep: str = "rep::amdgcn" + + def __iter__(self): + yield from (self.gen, self.arch, self.gpu, self.rep) + + def as_str(self) -> str: + return ", ".join(str(v) for v in self) + # C++ typename used for optional types EMPTY_TYPENAME = "empty_type" @@ -135,24 +188,28 @@ def translate_settings_to_cpp_metaprogramming( setting_list.append(f"(!std::is_same<{typename}, rocprim::{EMPTY_TYPENAME}>::value)") for name, value in const_configuration.items(): setting_list.append(f"({name} == {value})") - return "std::enable_if_t<(" + " && ".join(setting_list) + ")>" + return "if constexpr(" + " && ".join(setting_list) + ")" -class BenchmarksOfArchitecture: +class BenchmarksOfTarget: """ Stores the benchmark results for a specific architecture and algorithm. """ - def __init__(self, arch_name: str, config_selection_params, fallback_entries: List[FallbackCase], config_get_best, algorithm_name): - self.config_selection_params = config_selection_params - self.fallback_entries: List[FallbackCase] = fallback_entries - self.arch_name: str = arch_name - self.config_get_best: Callable[[Dict], Dict[str, str]] = config_get_best - self.algorithm_name: str = algorithm_name - # Dictionary storing the benchmarks - # Key is an instantiation of the configuration selection types - # Value is a list of all benchmark runs corresponding to that instantiation, - # these benchmarks in this list vary in the actual configuration used to run the benchmark - self.benchmarks = defaultdict(list) + def __init__( + self, + target: Target, + config_selection_params=None, + fallback_entries: Optional[List["FallbackCase"]] = None, + config_get_best: Optional[Callable[[Dict], Dict[str, str]]] = None, + algorithm_name: Optional[str] = None, + ): + self.__target: Target = target + self.config_selection_params = config_selection_params + self.fallback_entries: List["FallbackCase"] = fallback_entries or [] + self.config_get_best = config_get_best + self.algorithm_name = algorithm_name + self.algorithm_name_short = algorithm_name.replace("device_", "") if algorithm_name != None else None + self.benchmarks = defaultdict(list) def __get_instance_key(self, instanced_types): """ @@ -177,8 +234,11 @@ def add_measurement(self, benchmark_data: Dict[str, str]): self.benchmarks[instance_key].append(benchmark_data) @property - def name(self) -> str: - return self.arch_name + def target(self) -> Target: + return self.__target + + def get_enable_if(self, config_params) -> str: + return f'std::enable_if_t>::value, {config_params}>' def __get_best_benchmark(self, instance_key) -> Dict[str, str]: """ @@ -348,31 +408,38 @@ class Algorithm: """ def __init__(self, fallback_entries: List[FallbackCase], config_get_best = default_config_get_best): - self.architectures: Dict[str, BenchmarksOfArchitecture] = {} + self.targets: Dict[Target, BenchmarksOfTarget] = {} self.fallback_entries: List[FallbackCase] = fallback_entries self.config_get_best = config_get_best - def add_measurement(self, single_benchmark_data: Dict[str, str], architecture: str): + def __get_fallback_target(self) -> BenchmarksOfTarget: + """ + Returns M100 or first gpu in the list + """ + for target in self.targets: + if target.gpu == "mi100": + return self.targets[target] + return next(iter(self.targets.values())) + + + def add_measurement(self, single_benchmark_data: Dict[str, str], target: Target): """ Adds a single benchmark execution for a given architecture """ - if architecture not in self.architectures: - self.architectures[architecture] = BenchmarksOfArchitecture(architecture, self.config_selection_params, + if target not in self.targets: + self.targets[target] = BenchmarksOfTarget(target, self.config_selection_params, self.fallback_entries, self.config_get_best, self.algorithm_name) - self.architectures[architecture].add_measurement(single_benchmark_data) + self.targets[target].add_measurement(single_benchmark_data) def create_config_file_content(self) -> str: """ Generate the content of the configuration file, including license and header guards, based on general template file. """ - if 'target_arch::gfx908' in self.architectures: - self.architectures['target_arch::unknown'] = copy.deepcopy(self.architectures['target_arch::gfx908']) - self.architectures['target_arch::unknown'].arch_name = 'target_arch::unknown' - algorithm_template = env.get_template(self.cpp_configuration_template_name) - rendered_template = algorithm_template.render(all_architectures=self.architectures.values()) + fallback_target = self.__get_fallback_target() + rendered_template = algorithm_template.render(all_targets=self.targets.values(), fallback_target=fallback_target, unknown_target= BenchmarksOfTarget(Target())) return rendered_template @@ -808,11 +875,87 @@ def __get_target_architecture_from_context(self, benchmark_run): Uses the benchmark run context embedded into the benchmark output json to retrieve the targeted architecture """ - name_from_context = benchmark_run['context']['hdp_gcn_arch_name'].split(":")[0] - if name_from_context in TARGET_ARCHITECTURES: - return f'target_arch::{name_from_context}' + arch_from_context = benchmark_run['context']['hdp_gcn_arch_name'].split(":")[0] + if arch_from_context in TARGET_ARCHITECTURES: + return f'target_arch::{arch_from_context}' else: - raise RuntimeError(f"ERROR: unknown hdp_gcn_arch_name: {name_from_context}") + raise RuntimeError(f"ERROR: unknown hdp_gcn_arch_name: {arch_from_context}") + + def __get_target_gpu_from_context(self, benchmark_run): + """ + Uses the benchmark run context embedded into the benchmark output json to retrieve the targeted gpu + """ + + gpu_from_context = benchmark_run['context']['hdp_name'] + + ret = "gpu::generic" + + for gpu in TARGET_GPUS_DICT.keys(): + if gpu in gpu_from_context: + ret = f'gpu::{TARGET_GPUS_DICT[gpu]}' + + if ret == "gpu::generic": + print("WARNING: Could find gpu in defined gpus will use gpu::generic", file=sys.stderr, flush=True) + return ret + + def __get_target_rep_from_context(self, benchmark_run): + """ + Uses the benchmark run context embedded into the benchmark output json to retrieve the targeted rep + TODO The data is not yet inbedded in the benchmark output json + """ + return "rep::amdgcn" + + def __get_gen_from_architecture(self, arch): + match arch: + case "target_arch::gfx803": + return "gen::gcn3" + case "target_arch::gfx900" | "target_arch::gfx906": + return "gen::gcn5" + case "target_arch::gfx908": + return "gen::cdna1" + case "target_arch::gfx90a": + return "gen::cdna2" + case "target_arch::gfx942": + return "gen::cdna3" + case "target_arch::gfx950": + return "gen::cdna4" + case ( + "target_arch::gfx1010" + | "target_arch::gfx1011" + | "target_arch::gfx1012" + ): + return "gen::rdna1" + case "target_arch::gfx1030": + return "gen::rdna2" + case ( + "target_arch::gfx1100" + | "target_arch::gfx1101" + | "target_arch::gfx1102" + | "target_arch::gfx1103" + | "target_arch::gfx1150" + | "target_arch::gfx1151" + | "target_arch::gfx1152" + | "target_arch::gfx1153" + ): + return "gen::rdna3" + case "target_arch::gfx1200" | "target_arch::gfx1201": + return "gen::rdna4" + case "target_arch::unknown" | "target_arch::invalid": + return "gen::unknown" + case _: + return "gen::unknown" + + def __get_target(self, benchmark_run) -> Target: + """ + Uses the benchmark run context embedded into the benchmark output json to retrieve the target + """ + arch = self.__get_target_architecture_from_context(benchmark_run) + gpu = self.__get_target_gpu_from_context(benchmark_run) + rep = self.__get_target_rep_from_context(benchmark_run) + gen = self.__get_gen_from_architecture(arch) + + return Target(gen, arch, gpu, rep) + def __get_single_benchmark(self, single_benchmark): """ @@ -828,7 +971,7 @@ def __get_single_benchmark(self, single_benchmark): raise RuntimeError(f"ERROR: cannot parse JSON from: \"{single_benchmark['name']}\"") return dict(single_benchmark, **tokenized_name) - def __add_benchmark_to_algorithm(self, single_benchmark, arch): + def __add_benchmark_to_algorithm(self, single_benchmark, target): """ Adds a single_benchmark execution of a given Algorithm on a given architecture, to the Algorithm object @@ -839,7 +982,7 @@ def __add_benchmark_to_algorithm(self, single_benchmark, arch): algorithm_name += "_" + single_benchmark['subalgo'] if algorithm_name not in self.algorithms: self.algorithms[algorithm_name] = create_algorithm(algorithm_name, self.fallback_entries) - self.algorithms[algorithm_name].add_measurement(single_benchmark, arch) + self.algorithms[algorithm_name].add_measurement(single_benchmark, target) def add_run(self, benchmark_run_file_path: str): """ @@ -852,10 +995,11 @@ def add_run(self, benchmark_run_file_path: str): try: print(f'INFO: Processing "{benchmark_run_file_path}"') - arch = self.__get_target_architecture_from_context(benchmark_run_data) + target = self.__get_target(benchmark_run_data) + for raw_single_benchmark in benchmark_run_data['benchmarks']: single_benchmark = self.__get_single_benchmark(raw_single_benchmark) - self.__add_benchmark_to_algorithm(single_benchmark, arch) + self.__add_benchmark_to_algorithm(single_benchmark, target) print(f'INFO: Successfully processed file "{benchmark_run_file_path}"') except NotSupportedError as error: print(f'WARNING: Could not process file "{benchmark_run_file_path}": {error}', file=sys.stderr, flush=True) diff --git a/projects/rocprim/scripts/autotune/templates/adjacent_difference_config_template b/projects/rocprim/scripts/autotune/templates/adjacent_difference_config_template index 353e86d1fae..02c62d1d546 100644 --- a/projects/rocprim/scripts/autotune/templates/adjacent_difference_config_template +++ b/projects/rocprim/scripts/autotune/templates/adjacent_difference_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_DIFFERENCE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -adjacent_difference_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +adjacent_difference_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_adjacent_difference_config : default_adjacent_difference_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto adjacent_difference_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_adjacent_difference_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("adjacent_difference_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return adjacent_difference_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return adjacent_difference_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/adjacent_difference_inplace_config_template b/projects/rocprim/scripts/autotune/templates/adjacent_difference_inplace_config_template index 3e5d7fa5cdb..6cfb8a99aef 100644 --- a/projects/rocprim/scripts/autotune/templates/adjacent_difference_inplace_config_template +++ b/projects/rocprim/scripts/autotune/templates/adjacent_difference_inplace_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_DIFFERENCE_INPLACE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -adjacent_difference_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +adjacent_difference_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_adjacent_difference_inplace_config : default_adjacent_difference_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto adjacent_difference_inplace_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_adjacent_difference_inplace_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("adjacent_difference_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return adjacent_difference_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return adjacent_difference_inplace_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/adjacent_find_config_template b/projects/rocprim/scripts/autotune/templates/adjacent_find_config_template index 8ef25c83eb1..9bd7e9e1992 100644 --- a/projects/rocprim/scripts/autotune/templates/adjacent_find_config_template +++ b/projects/rocprim/scripts/autotune/templates/adjacent_find_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_FIND_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -adjacent_find_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +adjacent_find_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_adjacent_find_config : default_adjacent_find_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto adjacent_find_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_adjacent_find_config({{ benchmark_of_architecture.name }}), input_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("adjacent_find_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return adjacent_find_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return adjacent_find_config_picker, input_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/binary_search_config_template b/projects/rocprim/scripts/autotune/templates/binary_search_config_template index 00b19ca8380..d3a00c773ff 100644 --- a/projects/rocprim/scripts/autotune/templates/binary_search_config_template +++ b/projects/rocprim/scripts/autotune/templates/binary_search_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_BINARY_SEARCH_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -binary_search_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_binary_search_config : default_binary_search_config_base -{}; +{% macro config_picker() -%} +template constexpr auto binary_search_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_binary_search_config({{ benchmark_of_architecture.name }}), value_type, output_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return binary_search_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return binary_search_config_picker, value_type, output_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/config_template b/projects/rocprim/scripts/autotune/templates/config_template index 9d44e5cdcbd..1b8112c749b 100644 --- a/projects/rocprim/scripts/autotune/templates/config_template +++ b/projects/rocprim/scripts/autotune/templates/config_template @@ -41,15 +41,34 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -{{ general_case() }} - -{% for benchmark_of_architecture in all_architectures %} - {% for based_on_type, fallback_selection_criteria, measurement in benchmark_of_architecture.fallback_types %} -{{ configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) }} -{{ kernel_configuration(measurement) }} +{% set ns = namespace(algorithm_name='') %} +{% for benchmark_of_target in all_targets %} +{% set ns.algorithm_name = benchmark_of_target.algorithm_name_short %} +{{ config_picker() }} -> {{ enable_if(benchmark_of_target) }} +{ + {% for based_on_type, fallback_selection_criteria, measurement in benchmark_of_target.fallback_types %} +// Based on {{ based_on_type }} +{{ fallback_selection_criteria }} +{ return {{ kernel_configuration(measurement) }}; } {% endfor %} +// Default case if none of the conditions match +{{ default_case() }} +} + +{% endfor %} + +{{ config_picker() }} -> {{ enable_if(unknown_target) }} +{ +{{fallback_config(fallback_target)}} +} + +// All existing configs +using {{ns.algorithm_name}}_targets = comp_targets< +{% for benchmark_of_target in all_targets %} +comp_target<{{benchmark_of_target.target.as_str()}}>, {% endfor %} +comp_target<{{unknown_target.target.as_str()}}>>; } // end namespace detail diff --git a/projects/rocprim/scripts/autotune/templates/find_first_of_config_template b/projects/rocprim/scripts/autotune/templates/find_first_of_config_template index 84e7fb476aa..9f3aa58b36d 100644 --- a/projects/rocprim/scripts/autotune/templates/find_first_of_config_template +++ b/projects/rocprim/scripts/autotune/templates/find_first_of_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_FIND_FIRST_OF_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -find_first_of_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +find_first_of_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_find_first_of_config : default_find_first_of_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto find_first_of_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_find_first_of_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("find_first_of_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return find_first_of_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return find_first_of_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/histogram_config_template b/projects/rocprim/scripts/autotune/templates/histogram_config_template index 810d2b84fa6..0f82adb7dcc 100644 --- a/projects/rocprim/scripts/autotune/templates/histogram_config_template +++ b/projects/rocprim/scripts/autotune/templates/histogram_config_template @@ -5,17 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_HISTOGRAM_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -histogram_config, {{ measurement['cfg']['max_grid_size'] }}, {{ measurement['cfg']['shared_impl_max_bins'] }}, {{ measurement['cfg']['shared_impl_histograms'] }}, kernel_config<{{ measurement['cfg']['global_hist_bs'] }}, {{ measurement['cfg']['global_hist_ipt'] }}>> { }; +histogram_config_params{kernel_config_params{ {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, {{ measurement['cfg']['max_grid_size'] }}, {{ measurement['cfg']['shared_impl_max_bins'] }}, {{ measurement['cfg']['shared_impl_histograms'] }}, kernel_config_params{ {{ measurement['cfg']['global_hist_bs'] }}, {{ measurement['cfg']['global_hist_ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_histogram_config : -default_histogram_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto histogram_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_histogram_config({{ benchmark_of_architecture.name }}), value_type, channels, active_channels, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("histogram_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return histogram_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return histogram_config_picker, value_type, channels, active_channels>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/lower_bound_config_template b/projects/rocprim/scripts/autotune/templates/lower_bound_config_template index f2d49ff118a..445fbfbcc9b 100644 --- a/projects/rocprim/scripts/autotune/templates/lower_bound_config_template +++ b/projects/rocprim/scripts/autotune/templates/lower_bound_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_LOWER_BOUND_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -lower_bound_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_lower_bound_config : default_binary_search_config_base -{}; +{% macro config_picker() -%} +template constexpr auto lower_bound_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_lower_bound_config({{ benchmark_of_architecture.name }}), value_type, output_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return binary_search_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return lower_bound_config_picker, value_type, output_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/merge_config_template b/projects/rocprim/scripts/autotune/templates/merge_config_template index e89592f68ed..7234cca5260 100644 --- a/projects/rocprim/scripts/autotune/templates/merge_config_template +++ b/projects/rocprim/scripts/autotune/templates/merge_config_template @@ -5,21 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -merge_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +merge_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_merge_config : default_merge_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto merge_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_merge_config< - static_cast({{ benchmark_of_architecture.name }}), - key_type, - value_type, - {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("merge_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return merge_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return merge_config_picker, key_type, value_type>(); {%- endmacro %} \ No newline at end of file diff --git a/projects/rocprim/scripts/autotune/templates/mergesort_block_merge_config_template b/projects/rocprim/scripts/autotune/templates/mergesort_block_merge_config_template index 06e4236562f..f0c39409606 100644 --- a/projects/rocprim/scripts/autotune/templates/mergesort_block_merge_config_template +++ b/projects/rocprim/scripts/autotune/templates/mergesort_block_merge_config_template @@ -5,15 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_SORT_BLOCK_MERGE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, {{ measurement['cfg']['mergepath_partition_bs'] }}, {{ measurement['cfg']['mergepath_bs'] }}, {{ measurement['cfg']['mergepath_ipt'] }}> { }; +merge_sort_block_merge_config_params{ { 256, 1, (1 << 17) + 70000 }, { {{ measurement['cfg']['mergepath_partition_bs'] }}, 1}, { {{ measurement['cfg']['mergepath_bs'] }}, {{ measurement['cfg']['mergepath_ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template struct default_merge_sort_block_merge_config : -merge_sort_block_merge_config_base::type {}; +{% macro config_picker() -%} +template constexpr auto merge_sort_block_merge_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_merge_sort_block_merge_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("merge_sort_block_merge_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return merge_sort_block_merge_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return merge_sort_block_merge_config_picker, key_type, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/mergesort_block_sort_config_template b/projects/rocprim/scripts/autotune/templates/mergesort_block_sort_config_template index f3931f2c780..84666b0cf5a 100644 --- a/projects/rocprim/scripts/autotune/templates/mergesort_block_sort_config_template +++ b/projects/rocprim/scripts/autotune/templates/mergesort_block_sort_config_template @@ -5,19 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_SORT_BLOCK_SORT_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -merge_sort_block_sort_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +merge_sort_block_sort_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template struct default_merge_sort_block_sort_config : -merge_sort_block_sort_config_base::type {}; +{% macro config_picker() -%} +template constexpr auto merge_sort_block_sort_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_merge_sort_block_sort_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("merge_sort_block_sort_config_params") }} {%- endmacro %} +{% macro default_case() -%} +return merge_sort_block_sort_config_params_base(); +{%- endmacro %} - - +{% macro fallback_config(fallback_target) -%} + return merge_sort_block_sort_config_picker, key_type, value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_flag_config_template b/projects/rocprim/scripts/autotune/templates/partition_flag_config_template index 4495e5de7a8..e67f0541f04 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_flag_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_flag_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_FLAG_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_flag_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_flag_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_flag_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_flag_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_predicate_config_template b/projects/rocprim/scripts/autotune/templates/partition_predicate_config_template index 021a343ae3d..95892f528a8 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_predicate_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_predicate_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_PREDICATE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_predicate_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_predicate_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_predicate_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_predicate_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_three_way_config_template b/projects/rocprim/scripts/autotune/templates/partition_three_way_config_template index 01f06bc8b1f..7929fdfb30c 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_three_way_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_three_way_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_THREE_WAY_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_three_way_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_three_way_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_three_way_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_three_way_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_two_way_flag_config_template b/projects/rocprim/scripts/autotune/templates/partition_two_way_flag_config_template index 16826143e61..0eca0b9da8e 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_two_way_flag_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_two_way_flag_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_TWO_WAY_FLAG_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_two_way_flag_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_two_way_flag_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_two_way_flag_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_two_way_flag_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_two_way_predicate_config_template b/projects/rocprim/scripts/autotune/templates/partition_two_way_predicate_config_template index ac11f1459de..a4c9a927745 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_two_way_predicate_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_two_way_predicate_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_TWO_WAY_PREDICATE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_two_way_predicate_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_two_way_predicate_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_two_way_predicate_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_two_way_predicate_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/radixsort_block_sort_config_template b/projects/rocprim/scripts/autotune/templates/radixsort_block_sort_config_template index b27a6aa30d8..5f664fee5f7 100644 --- a/projects/rocprim/scripts/autotune/templates/radixsort_block_sort_config_template +++ b/projects/rocprim/scripts/autotune/templates/radixsort_block_sort_config_template @@ -5,19 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RADIX_SORT_BLOCK_SORT_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -kernel_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +kernel_config_params{ {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_radix_sort_block_sort_config : -radix_sort_block_sort_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto radix_sort_block_sort_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_radix_sort_block_sort_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("kernel_config_params") }} {%- endmacro %} - - +{% macro default_case() -%} +return radix_sort_block_sort_config_params_base(); +{%- endmacro %} +{% macro fallback_config(fallback_target) -%} + return radix_sort_block_sort_config_picker, key_type, value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/radixsort_onesweep_config_template b/projects/rocprim/scripts/autotune/templates/radixsort_onesweep_config_template index 591df8ef88b..3642304fce9 100644 --- a/projects/rocprim/scripts/autotune/templates/radixsort_onesweep_config_template +++ b/projects/rocprim/scripts/autotune/templates/radixsort_onesweep_config_template @@ -5,20 +5,25 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RADIX_SORT_ONESWEEP_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -radix_sort_onesweep_config, -kernel_config<{{ measurement['cfg']['sort']['bs'] }}, {{ measurement['cfg']['sort']['ipt'] }}>, {{ measurement['cfg']['bits_per_place'] }}, {{ measurement['cfg']['algorithm'] }}> { }; +radix_sort_onesweep_config_params{kernel_config_params{ {{ measurement['cfg']['histogram']['bs'] }}, {{ measurement['cfg']['histogram']['ipt'] }} }, +kernel_config_params{ {{ measurement['cfg']['sort']['bs'] }}, {{ measurement['cfg']['sort']['ipt'] }} }, {{ measurement['cfg']['bits_per_place'] }}, {{ measurement['cfg']['algorithm'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_radix_sort_onesweep_config : -radix_sort_onesweep_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto radix_sort_onesweep_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_radix_sort_onesweep_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("radix_sort_onesweep_config_params") }} {%- endmacro %} +{% macro default_case() -%} +return radix_sort_onesweep_config_params_base(); +{%- endmacro %} - - +{% macro fallback_config(fallback_target) -%} + return radix_sort_onesweep_config_picker< + comp_target<{{ fallback_target.target.as_str() }}>, + key_type, + value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/reduce_by_key_config_template b/projects/rocprim/scripts/autotune/templates/reduce_by_key_config_template index e13ea2e5990..55fd225410f 100644 --- a/projects/rocprim/scripts/autotune/templates/reduce_by_key_config_template +++ b/projects/rocprim/scripts/autotune/templates/reduce_by_key_config_template @@ -5,26 +5,25 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_REDUCE_BY_KEY_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -reduce_by_key_config< - {{ measurement['cfg']['bs'] }}, - {{ measurement['cfg']['ipt'] }}, +reduce_by_key_config_params{ + { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, block_load_method::block_load_transpose, block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> { }; + block_scan_algorithm::using_warp_scan } {%- endmacro %} -{% macro general_case() -%} -template -struct default_reduce_by_key_config : default_reduce_by_key_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto reduce_by_key_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_reduce_by_key_config< - static_cast({{ benchmark_of_architecture.name }}), - key_type, - value_type, - {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("reduce_by_key_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return reduce_by_key_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return reduce_by_key_config_picker, key_type, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/reduce_config_template b/projects/rocprim/scripts/autotune/templates/reduce_config_template index 3d3924cf39e..c28969387c5 100644 --- a/projects/rocprim/scripts/autotune/templates/reduce_config_template +++ b/projects/rocprim/scripts/autotune/templates/reduce_config_template @@ -5,19 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_REDUCE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -reduce_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::rocprim::block_reduce_algorithm::{{ measurement['cfg']['method'] }}> { }; +reduce_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::block_reduce_algorithm::{{ measurement['cfg']['method'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_reduce_config : -default_reduce_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto reduce_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_reduce_config({{ benchmark_of_architecture.name }}), key_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("reduce_config_params") }} {%- endmacro %} - - +{% macro default_case() -%} +return reduce_config_params_base(); +{%- endmacro %} +{% macro fallback_config(fallback_target) -%} + return reduce_config_picker, key_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/run_length_encode_config_template b/projects/rocprim/scripts/autotune/templates/run_length_encode_config_template index 17b2003d7da..e443a1f8662 100644 --- a/projects/rocprim/scripts/autotune/templates/run_length_encode_config_template +++ b/projects/rocprim/scripts/autotune/templates/run_length_encode_config_template @@ -5,26 +5,25 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RUN_LENGTH_ENCODE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -reduce_by_key_config< - {{ measurement['cfg']['bs'] }}, - {{ measurement['cfg']['ipt'] }}, +reduce_by_key_config_params{ + { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, block_load_method::block_load_transpose, block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> { }; + block_scan_algorithm::using_warp_scan} {%- endmacro %} -{% macro general_case() -%} -template -struct default_trivial_runs_config : default_reduce_by_key_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto run_length_encode_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_trivial_runs_config< - static_cast({{ benchmark_of_architecture.name }}), - key_type, - value_type, - {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("reduce_by_key_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return reduce_by_key_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return run_length_encode_config_picker, key_type, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/run_length_encode_non_trivial_runs_config_template b/projects/rocprim/scripts/autotune/templates/run_length_encode_non_trivial_runs_config_template index 49ce27a781c..e82a0225496 100644 --- a/projects/rocprim/scripts/autotune/templates/run_length_encode_non_trivial_runs_config_template +++ b/projects/rocprim/scripts/autotune/templates/run_length_encode_non_trivial_runs_config_template @@ -5,24 +5,24 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RUN_LENGTH_ENCODE_NON_TRIVIAL_RUNS_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -non_trivial_runs_config< - {{ measurement['cfg']['bs'] }}, - {{ measurement['cfg']['ipt'] }}, +non_trivial_runs_config_params{ + { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::{{ measurement['cfg']['load_method'] }}, - ::rocprim::block_scan_algorithm::using_warp_scan> { }; + ::rocprim::block_scan_algorithm::using_warp_scan} {%- endmacro %} -{% macro general_case() -%} -template -struct default_non_trivial_runs_config : default_non_trivial_runs_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto run_length_encode_non_trivial_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_non_trivial_runs_config< - static_cast({{ benchmark_of_architecture.name }}), - key_type, - {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("non_trivial_runs_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return non_trivial_runs_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return run_length_encode_non_trivial_config_picker, key_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/scan_config_template b/projects/rocprim/scripts/autotune/templates/scan_config_template index 6de5a83ad95..3602f395b50 100644 --- a/projects/rocprim/scripts/autotune/templates/scan_config_template +++ b/projects/rocprim/scripts/autotune/templates/scan_config_template @@ -5,19 +5,23 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SCAN_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -scan_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, {{ measurement['cfg']['method'] }}> { }; +scan_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, {{ measurement['cfg']['method'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_scan_config : -default_scan_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto scan_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_scan_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("scan_config_params") }} {%- endmacro %} - - +{% macro default_case() -%} +return scan_config_params_base(); +{%- endmacro %} +{% macro fallback_config(fallback_target) -%} + return scan_config_picker< + comp_target<{{ fallback_target.target.as_str() }}>, + value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/scanbykey_config_template b/projects/rocprim/scripts/autotune/templates/scanbykey_config_template index 72653bb61a4..52138a1b966 100644 --- a/projects/rocprim/scripts/autotune/templates/scanbykey_config_template +++ b/projects/rocprim/scripts/autotune/templates/scanbykey_config_template @@ -5,19 +5,24 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SCAN_BY_KEY_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -scan_by_key_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, {{ measurement['cfg']['method'] }}> { }; +scan_by_key_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, {{ measurement['cfg']['method'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_scan_by_key_config : -default_scan_by_key_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto scan_by_key_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_scan_by_key_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("scan_by_key_config_params") }} {%- endmacro %} - - +{% macro default_case() -%} +return scan_by_key_config_params_base(); +{%- endmacro %} +{% macro fallback_config(fallback_target) -%} + return scan_by_key_config_picker< + comp_target<{{ fallback_target.target.as_str() }}>, + key_type, + value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/search_n_config_template b/projects/rocprim/scripts/autotune/templates/search_n_config_template index 7ad5393b1e7..919b5c29b03 100644 --- a/projects/rocprim/scripts/autotune/templates/search_n_config_template +++ b/projects/rocprim/scripts/autotune/templates/search_n_config_template @@ -5,15 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEARCH_N_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -search_n_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, {{ measurement['cfg']['threshold'] }}> { }; +search_n_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, {{ measurement['cfg']['threshold'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_search_n_config : -default_search_n_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto search_n_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_search_n_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("search_n_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return search_n_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return search_n_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/segmented_radix_sort_config_template b/projects/rocprim/scripts/autotune/templates/segmented_radix_sort_config_template index 437718b9afd..8e625904636 100644 --- a/projects/rocprim/scripts/autotune/templates/segmented_radix_sort_config_template +++ b/projects/rocprim/scripts/autotune/templates/segmented_radix_sort_config_template @@ -5,28 +5,37 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -segmented_radix_sort_config< +segmented_radix_sort_config_params{ {{ measurement['cfg']['rb'] }}, - kernel_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}>, - typename std::conditional< - {{ measurement['cfg']['wsc']['pa'] }}, - WarpSortConfig< + kernel_config_params{ {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, +{% if measurement['cfg']['wsc']['pa'] == 1 -%} + warp_sort_config_params{ + {{ measurement['cfg']['wsc']['pa'] }}, {{ measurement['cfg']['wsc']['lwss'] }}, {{ measurement['cfg']['wsc']['ipts'] }}, {{ measurement['cfg']['wsc']['bss'] }}, {{ measurement['cfg']['wsc']['pt'] }}, {{ measurement['cfg']['wsc']['lwsm'] }}, {{ measurement['cfg']['wsc']['iptm'] }}, - {{ measurement['cfg']['wsc']['bsm'] }}>, - DisabledWarpSortConfig - >::type, - {{ measurement['cfg']['eupws'] }} > { }; + {{ measurement['cfg']['wsc']['bsm'] }} }, +{% else -%} + warp_sort_config_params{0}, +{% endif -%} + {{ measurement['cfg']['eupws'] }} } {%- endmacro %} -{% macro general_case() -%} -template -struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<6>::type -{}; +{% macro config_picker() -%} +template constexpr auto segmented_radix_sort_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_segmented_radix_sort_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("segmented_radix_sort_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return segmented_radix_sort_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return segmented_radix_sort_config_picker< + comp_target<{{ fallback_target.target.as_str() }}>, + key_type, + value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/segmented_reduce_config_template b/projects/rocprim/scripts/autotune/templates/segmented_reduce_config_template index 3b6ed10bcf3..acce73162a5 100644 --- a/projects/rocprim/scripts/autotune/templates/segmented_reduce_config_template +++ b/projects/rocprim/scripts/autotune/templates/segmented_reduce_config_template @@ -5,15 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_REDUCE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -reduce_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::rocprim::block_reduce_algorithm::{{ measurement['cfg']['method'] }}> { }; +reduce_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::block_reduce_algorithm::{{ measurement['cfg']['method'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_segmented_reduce_config : -default_reduce_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto segmented_reduce_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_segmented_reduce_config({{ benchmark_of_architecture.name }}), key_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("reduce_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return reduce_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return segmented_reduce_config_picker, key_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_flag_config_template b/projects/rocprim/scripts/autotune/templates/select_flag_config_template index 36a7055e485..80f8a8aba18 100644 --- a/projects/rocprim/scripts/autotune/templates/select_flag_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_flag_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_FLAG_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_flag_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_flag_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_flag_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_flag_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_predicate_config_template b/projects/rocprim/scripts/autotune/templates/select_predicate_config_template index 137ba5ee387..b5450975e32 100644 --- a/projects/rocprim/scripts/autotune/templates/select_predicate_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_predicate_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_predicate_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_predicate_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_predicate_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_predicate_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_predicated_flag_config_template b/projects/rocprim/scripts/autotune/templates/select_predicated_flag_config_template index 2ba3565aefd..7907d92b2df 100644 --- a/projects/rocprim/scripts/autotune/templates/select_predicated_flag_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_predicated_flag_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_predicated_flag_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_predicated_flag_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_predicated_flag_config({{ benchmark_of_architecture.name }}), data_type, flag_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_predicated_flag_config_picker, data_type, flag_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_unique_by_key_config_template b/projects/rocprim/scripts/autotune/templates/select_unique_by_key_config_template index 2177f95af1d..200b2de7b2f 100644 --- a/projects/rocprim/scripts/autotune/templates/select_unique_by_key_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_unique_by_key_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_BY_KEY_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_unique_by_key_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_unique_by_key_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_unique_by_key_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_unique_by_key_config_picker, key_type, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_unique_config_template b/projects/rocprim/scripts/autotune/templates/select_unique_config_template index 9bb8aec9f34..9c005946dca 100644 --- a/projects/rocprim/scripts/autotune/templates/select_unique_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_unique_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_unique_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_unique_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_unique_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_unique_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/transform_config_template b/projects/rocprim/scripts/autotune/templates/transform_config_template index bc9e4fe8107..fbbef6af1b5 100644 --- a/projects/rocprim/scripts/autotune/templates/transform_config_template +++ b/projects/rocprim/scripts/autotune/templates/transform_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_TRANSFORM_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -transform_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_transform_config : default_transform_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto transform_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_transform_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return transform_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return transform_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/transform_pointer_config_template b/projects/rocprim/scripts/autotune/templates/transform_pointer_config_template index efa1193dff1..46ad6667241 100644 --- a/projects/rocprim/scripts/autotune/templates/transform_pointer_config_template +++ b/projects/rocprim/scripts/autotune/templates/transform_pointer_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_TRANSFORM_POINTER_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -transform_pointer_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::{{ measurement['cfg']['lt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::{{ measurement['cfg']['lt'] }} } {%- endmacro %} -{% macro general_case() -%} -template -struct default_transform_pointer_config : default_transform_pointer_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto transform_pointer_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_transform_pointer_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return transform_pointer_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return transform_pointer_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/upper_bound_config_template b/projects/rocprim/scripts/autotune/templates/upper_bound_config_template index 05d05c4de98..9bd9c8a83c3 100644 --- a/projects/rocprim/scripts/autotune/templates/upper_bound_config_template +++ b/projects/rocprim/scripts/autotune/templates/upper_bound_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_UPPER_BOUND_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -upper_bound_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_upper_bound_config : default_binary_search_config_base -{}; +{% macro config_picker() -%} +template constexpr auto upper_bound_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_upper_bound_config({{ benchmark_of_architecture.name }}), value_type, output_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return binary_search_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return upper_bound_config_picker, value_type, output_type>(); {%- endmacro %} diff --git a/projects/rocprim/test/rocprim/test_config_dispatch.cpp b/projects/rocprim/test/rocprim/test_config_dispatch.cpp index 29de22eb0e6..7d39aaf5081 100644 --- a/projects/rocprim/test/rocprim/test_config_dispatch.cpp +++ b/projects/rocprim/test/rocprim/test_config_dispatch.cpp @@ -36,7 +36,14 @@ void write_target_arch([[maybe_unused]] target_arch host_arch, int* __restrict__ #if !defined(ROCPRIM_TARGET_SPIRV) static constexpr auto arch = rocprim::detail::device_target_arch(); - *result = arch == host_arch; + if constexpr(!ROCPRIM_IS_GENERIC()) + { + *result = arch == host_arch; + } + else + { + *result = -2; + } #else *result = -1; #endif @@ -94,8 +101,15 @@ TEST(RocprimConfigDispatchTests, HostMatchesDevice) if(result != -1) { - ASSERT_NE(host_arch, target_arch::invalid); - ASSERT_EQ(result, 1); + if(result != -2) + { + ASSERT_NE(host_arch, target_arch::invalid); + ASSERT_EQ(result, 1); + } + else + { + GTEST_SKIP() << "Generic build: result is null; skipping arch match assertion."; + } } else { @@ -160,3 +174,289 @@ TEST(RocprimConfigDispatchTests, DeviceIdFromStream) ASSERT_EQ(result, device_id); } #endif + +namespace rocprim +{ +namespace detail +{ + +using Targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; + +struct test_config_params +{ + kernel_config_params kernel_config{}; + + __host__ __device__ + inline bool + operator==(test_config_params other) const + { + return (kernel_config.block_size == other.kernel_config.block_size) + && (kernel_config.items_per_thread == other.kernel_config.items_per_thread); + } +}; + +template +struct TestSelector +{ + using targets = TargetsInput; + using param_type = test_config_params; + + param_type params; + + template + constexpr TestSelector(Target) + : params(test_config_params{ + {BlockSize, ItemsPerThread} + }) + {} +}; + +template +struct TestSelector2 +{ + using targets = TargetsInput; + using param_type = test_config_params; + + param_type params; + + template + constexpr param_type picker_helper() + { + constexpr auto arch = Target::i; + constexpr auto gen = Target::g; + constexpr auto gpu = Target::s; + + // Assign unique configs per architecture/gen/GPU + if constexpr(arch == target_arch::gfx1030 && gen == gen::rdna2 && gpu == gpu::rx6900) + { + return param_type{ + {64, 1} + }; + } + else if constexpr(arch == target_arch::gfx1100 && gen == gen::rdna3 && gpu == gpu::rx7900) + { + return param_type{ + {128, 2} + }; + } + else if constexpr(arch == target_arch::gfx908 && gen == gen::cdna1 && gpu == gpu::mi100) + { + return param_type{ + {192, 3} + }; + } + else if constexpr(arch == target_arch::gfx90a && gen == gen::cdna2 && gpu == gpu::mi210) + { + return param_type{ + {256, 4} + }; + } + else if constexpr(arch == target_arch::gfx942 && gen == gen::cdna3 && gpu == gpu::mi300x) + { + return param_type{ + {320, 5} + }; + } + else if constexpr(arch == target_arch::gfx942 && gen == gen::cdna3 && gpu == gpu::mi300a) + { + return param_type{ + {384, 6} + }; + } + else if constexpr(arch == target_arch::gfx1200 && gen == gen::rdna4 && gpu == gpu::rx9060) + { + return param_type{ + {448, 7} + }; + } + else if constexpr(arch == target_arch::gfx1201 && gen == gen::rdna4 && gpu == gpu::rx9070) + { + return param_type{ + {512, 8} + }; + } + else if constexpr(arch == target_arch::gfx950 && gen == gen::cdna4 && gpu == gpu::mi350x) + { + return param_type{ + {576, 9} + }; + } + else + { + // Default fallback for unknown targets + return param_type{ + {1, 1} + }; + } + } + + template + constexpr TestSelector2(Target) : params(picker_helper()) + {} +}; + +} // namespace detail +} // namespace rocprim + +// This test needs to be changed once the selection logic changes for most_common_config. +TEST(RocprimConfigDispatchTests, MostCommonConfig) +{ + using namespace rocprim::detail; + + // 1. Default-constructed target (all unknowns) should return itself. + ASSERT_EQ(most_common_config(target()), target()); + // 2. Exact match (same arch & gpu) should return that target. + ASSERT_EQ(most_common_config(target(target_arch::gfx1200, gpu::rx9060)), + target(target_arch::gfx1200, gpu::rx9060)); + // 3. Unknown arch/gpu combination should return default unknown. + ASSERT_EQ(most_common_config(target(target_arch::gfx906, gpu::mi50)), target()); + // 4. Multiple entries with same arch — latest wins (mi300a over mi300x). + ASSERT_EQ(most_common_config(target(target_arch::gfx942, gpu::mi308x)), + target(target_arch::gfx942, gpu::mi300a)); + // 5. Target with same generation but unknown arch — picks matching gen target. + ASSERT_EQ(most_common_config(target(target_arch::gfx1102)), + target(target_arch::gfx1100, gpu::rx7900)); + // 6. Generation-only match (no arch/gpu known). + ASSERT_EQ(most_common_config(target(gen::rdna2)), + target(target_arch::gfx1030, gpu::rx6900)); + // 7. Arch-only match (gpu::generic, known arch). + ASSERT_EQ(most_common_config(target(target_arch::gfx90a)), + target(target_arch::gfx90a, gpu::mi210)); + // 8. If gen has multiple matching targets (cdna3), latest listed is chosen. + ASSERT_EQ(most_common_config(target(gen::cdna3)), + target(target_arch::gfx942, gpu::mi300a)); +} + +// This test needs to be changed once the selection logic changes for most_common_config. +TEST(RocprimConfigDispatchTests, DefaultSelectConfig) +{ + using namespace rocprim::detail; + + using Selector = TestSelector2; + using Params = typename Selector::param_type; + + auto cfg = [](target t) { return default_select_config(t); }; + + // 1. Default target (unknown) + ASSERT_EQ(cfg(target()), + (Params{ + {1, 1} + })); + + // 2. Exact match (gfx1200, rx9060) + ASSERT_EQ(cfg(target(target_arch::gfx1200, gpu::rx9060)), + (Params{ + {448, 7} + })); + + // 3. Unknown arch/gpu + ASSERT_EQ(cfg(target(target_arch::gfx906, gpu::mi50)), + (Params{ + {1, 1} + })); + + // 4. Multiple entries same arch (gfx942) — latest wins (mi300a) + ASSERT_EQ(cfg(target(target_arch::gfx942, gpu::mi308x)), + (Params{ + {384, 6} + })); + + // 5. Same gen but unknown arch (gfx1102) + ASSERT_EQ(cfg(target(target_arch::gfx1102)), + (Params{ + {128, 2} + })); + + // 6. Generation-only match (rdna2) + ASSERT_EQ(cfg(target(gen::rdna2)), + (Params{ + {64, 1} + })); + + // 7. Arch-only match (gfx90a) + ASSERT_EQ(cfg(target(target_arch::gfx90a)), + (Params{ + {256, 4} + })); + + // 8. Generation with multiple matches (cdna3 -> gfx942/mi300a) + ASSERT_EQ(cfg(target(gen::cdna3)), + (Params{ + {384, 6} + })); + + // 9. Other targets in the list + ASSERT_EQ(cfg(target(target_arch::gfx1201, gpu::rx9070)), + (Params{ + {512, 8} + })); + ASSERT_EQ(cfg(target(target_arch::gfx950, gpu::mi350x)), + (Params{ + {576, 9} + })); +} + +TEST(RocprimConfigDispatchTests, ExecuteLaunchPlan) +{ + using namespace rocprim::detail; + using EmptyTargets + = comp_targets>; + using Config = rocprim::default_config; + + constexpr unsigned int block_size = 256; + constexpr unsigned int ipt = 2; + + hipStream_t stream = 0; + + target_arch target_arch; + HIP_CHECK(host_target_arch(stream, target_arch)); + gpu target_gpu; + HIP_CHECK(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + target* d_output; + HIP_CHECK(hipMalloc(&d_output, sizeof(target))); + + auto kernel = [=](auto arch_config) { *d_output = decltype(arch_config)::config_target; }; + + HIP_CHECK((execute_launch_plan, + default_config_static_selector>(current_target, + kernel, + dim3(1), + dim3(block_size), + 0, + stream))); + + // Impossible target + target h_output{gen::cdna4, target_arch::gfx942, gpu::rx6900, rep::spirv}; + HIP_CHECK(hipMemcpy(&h_output, d_output, sizeof(target), hipMemcpyDeviceToHost)); + // Compared to targets with only unknown inside. + ASSERT_EQ(target(), h_output); + + HIP_CHECK((execute_launch_plan, + default_config_static_selector>(current_target, + kernel, + dim3(1), + dim3(block_size), + 0, + stream))); + + // Impossible target + h_output = target{gen::cdna4, target_arch::gfx942, gpu::rx6900, rep::spirv}; + HIP_CHECK(hipMemcpy(&h_output, d_output, sizeof(target), hipMemcpyDeviceToHost)); + // Should have the same targets as most_common_config. + ASSERT_EQ(most_common_config(current_target), h_output); +} diff --git a/projects/rocprim/test/rocprim/test_device_batch_memcpy.cpp b/projects/rocprim/test/rocprim/test_device_batch_memcpy.cpp index 48060147c94..d5009f89d97 100644 --- a/projects/rocprim/test/rocprim/test_device_batch_memcpy.cpp +++ b/projects/rocprim/test/rocprim/test_device_batch_memcpy.cpp @@ -256,15 +256,20 @@ TYPED_TEST(RocprimDeviceBatchMemcpyTests, SizeAndTypeVariation) constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; constexpr bool debug_synchronous = TestFixture::debug_synchronous; - using config = rocprim::detail:: - wrapped_batch_memcpy_config; + using Selector = rocprim::detail::batch_memcpy_config_selector; rocprim::detail::target_arch target_arch; - hipError_t success = rocprim::detail::host_target_arch(hipStreamDefault, target_arch); + hipError_t success = host_target_arch(hipStreamDefault, target_arch); + + rocprim::detail::gpu target_gpu; + success = host_target_gpu(hipStreamDefault, target_gpu); + ASSERT_EQ(success, hipSuccess); - const rocprim::detail::batch_memcpy_config_params params - = rocprim::detail::dispatch_target_arch(target_arch); + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params + = rocprim::detail::get_config(rocprim::default_config{}, get_target); const int32_t wlev_min_size = params.wlev_size_threshold; const int32_t blev_min_size = params.blev_size_threshold; diff --git a/projects/rocprim/test/rocprim/test_device_scan.cpp b/projects/rocprim/test/rocprim/test_device_scan.cpp index 8327b89fe63..aecb900a1c1 100644 --- a/projects/rocprim/test/rocprim/test_device_scan.cpp +++ b/projects/rocprim/test/rocprim/test_device_scan.cpp @@ -305,15 +305,20 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScan) const bool deterministic = TestFixture::deterministic; const bool use_initial_value = TestFixture::use_initial_value; - using Config = typename TestFixture::config_helper; - using config = rocprim::detail::wrapped_scan_config; + using Config = typename TestFixture::config_helper; + using Selector = rocprim::detail::scan_config_selector; hipStream_t stream = hipStreamDefault; rocprim::detail::target_arch target_arch; - HIP_CHECK(host_target_arch(stream, target_arch)); - const rocprim::detail::scan_config_params params - = rocprim::detail::dispatch_target_arch(target_arch); + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target current_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(Config{}, current_target); // For non-associative operations in inclusive scan // intermediate results use the type of input iterator, then @@ -497,12 +502,13 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScan) false, ordered_bid); }; - return rocprim::detail::execute_launch_plan(target_arch, - lookback_scan_kernel, - dim3(grid_size), - dim3(block_size), - 0, - stream); + return rocprim::detail::execute_launch_plan( + current_target, + lookback_scan_kernel, + dim3(grid_size), + dim3(block_size), + 0, + stream); }); ASSERT_EQ(hipSuccess, launch_err); @@ -553,15 +559,20 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScanGetCompleteValue) const bool deterministic = TestFixture::deterministic; const bool use_initial_value = TestFixture::use_initial_value; - using Config = typename TestFixture::config_helper; - using config = rocprim::detail::wrapped_scan_config; + using Config = typename TestFixture::config_helper; + using Selector = rocprim::detail::scan_config_selector; hipStream_t stream = hipStreamDefault; rocprim::detail::target_arch target_arch; - HIP_CHECK(host_target_arch(stream, target_arch)); - const rocprim::detail::scan_config_params params - = rocprim::detail::dispatch_target_arch(target_arch); + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target current_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(Config{}, current_target); // For non-associative operations in inclusive scan // intermediate results use the type of input iterator, then @@ -735,12 +746,12 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScanGetCompleteValue) false, ordered_bid); }; - return rocprim::detail::execute_launch_plan(target_arch, - lookback_scan_kernel, - dim3(grid_size), - dim3(block_size), - 0, - stream); + return rocprim::detail::execute_launch_plan(current_target, + lookback_scan_kernel, + dim3(grid_size), + dim3(block_size), + 0, + stream); }); ASSERT_EQ(hipSuccess, launch_err); diff --git a/projects/rocprim/test/rocprim/test_device_search_n.cpp b/projects/rocprim/test/rocprim/test_device_search_n.cpp index c7622e319fc..1b8eee34429 100644 --- a/projects/rocprim/test/rocprim/test_device_search_n.cpp +++ b/projects/rocprim/test/rocprim/test_device_search_n.cpp @@ -904,13 +904,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_1block) for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target current_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1025,13 +1030,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_2block) for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target current_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1146,13 +1156,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_3block) for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target current_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1267,13 +1282,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult1) for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target current_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1389,13 +1409,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult2) = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target current_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; diff --git a/projects/rocprim/test/rocprim/test_linking_new_scan.hpp b/projects/rocprim/test/rocprim/test_linking_new_scan.hpp index ef51bcb30ea..a38953bc99e 100644 --- a/projects/rocprim/test/rocprim/test_linking_new_scan.hpp +++ b/projects/rocprim/test/rocprim/test_linking_new_scan.hpp @@ -98,15 +98,17 @@ inline auto scan_impl(void* temporary_storage, { (void)debug_synchronous; - using config = wrapped_scan_config; + using Selector = scan_config_selector; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const scan_config_params params = dispatch_target_arch(target_arch); + rocprim::detail::target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + rocprim::detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -133,12 +135,12 @@ inline auto scan_impl(void* temporary_storage, output, scan_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - single_scan_kernel, - dim3(1), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + single_scan_kernel, + dim3(1), + dim3(block_size), + 0, + stream)); return hipGetLastError(); }